diff --git a/docs/source/msei_reference/process.rst b/docs/source/msei_reference/process.rst index b004fe46..33a520c4 100644 --- a/docs/source/msei_reference/process.rst +++ b/docs/source/msei_reference/process.rst @@ -1,36 +1,7 @@ Random Process ============== -AR Processes ------------- - -.. automodule:: pyrenew.process.ar +.. automodule:: pyrenew.process :members: :undoc-members: :show-inheritance: - -First Difference (AR) ---------------------- - -.. automodule:: pyrenew.process.firstdifferencear - :members: - :undoc-members: - :show-inheritance: - -Reproduction Number Random Walk -------------------------------- - -.. automodule:: pyrenew.process.rtrandomwalk - :members: - :undoc-members: - :show-inheritance: - -Simple Random Walk ------------------- - -.. automodule:: pyrenew.process.simplerandomwalk - :members: - :undoc-members: - :show-inheritance: - -.. todo:: Determine order and naming of these modules. diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index ace4a3f7..fa03f957 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -14,9 +14,9 @@ We start by loading the needed components to build a basic renewal model: # | warning: false import jax.numpy as jnp import numpy as np -import numpyro as npro +import numpyro import numpyro.distributions as dist -from pyrenew.process import RtRandomWalkProcess +from pyrenew.process import SimpleRandomWalkProcess from pyrenew.latent import ( Infections, InfectionInitializationProcess, @@ -25,10 +25,15 @@ from pyrenew.latent import ( from pyrenew.observation import PoissonObservation from pyrenew.deterministic import DeterministicPMF from pyrenew.model import RtInfectionsRenewalModel -from pyrenew.metaclass import DistributionalRV +from pyrenew.metaclass import ( + RandomVariable, + DistributionalRV, + TransformedRandomVariable, +) import pyrenew.transformation as t +from numpyro.infer.reparam import 
LocScaleReparam -npro.set_host_device_count(2) +numpyro.set_host_device_count(2) ``` ## Architecture of `RtInfectionsRenewalModel` @@ -51,20 +56,20 @@ flowchart LR models((Model\nmetaclass)) subgraph observations[Observations module] - obs["observation_process\n(PoissonObservation)"] + obs["infection_obs_process_rv\n(PoissonObservation)"] end subgraph latent[Latent module] - inf["latent_infections\n(Infections)"] - i0["I0\n(DistributionalRV)"] + inf["latent_infections_rv\n(Infections)"] + i0["I0_rv\n(DistributionalRV)"] end subgraph process[Process module] - rt["rt_proc\n(RtRandomWalkProcess)"] + rt["Rt_process_rv\n(Custom class built using SimpleRandomWalk)"] end subgraph deterministic[Deterministic module] - detpmf["gen_int\n(DeterministicPMF)"] + detpmf["gen_int_rv\n(DeterministicPMF)"] end subgraph model[Model module] @@ -85,13 +90,13 @@ flowchart LR ``` -The pyrenew package models the real-time reproductive number $R_t$, the average number of secondary infections caused by an infected individual, as a renewal process model. Our basic renewal process model defines five components: +The pyrenew package models the real-time reproductive number $\mathcal{R}(t)$, the average number of secondary infections caused by an infected individual, as a renewal process model. 
Our basic renewal process model defines five components: (1) generation interval, the times between infections (2) initial infections, occurring prior to time $t = 0$ -(3) $R_t$, the real-time reproductive number, +(3) $\mathcal{R}(t)$, the time-varying reproductive number, (4) latent infections, i.e., those infections which are known to exist but are not observed (or not observable), and @@ -103,7 +108,7 @@ To initialize these five components within the renewal modeling framework, we es (2) an instance of the `InfectionInitializationProcess` class, where the number of latent infections immediately before the renewal process begins follows a log-normal distribution with mean = 0 and standard deviation = 1. By specifying `InitializeInfectionsZeroPad`, the latent infections before this time are assumed to be 0. -(3) an instance of the `RtRandomWalkProcess` class with default values +(3) a process to represent $\mathcal{R}(t)$ as a random walk on the log scale, with an inferred initial value and Normal steps whose standard deviation is also inferred. For this, we construct a custom `RandomVariable`, `MyRt`. 
(4) an instance of the `Infections` class with default values, and @@ -112,23 +117,47 @@ To initialize these five components within the renewal modeling framework, we es ```{python} # | label: creating-elements # (1) The generation interval (deterministic) -pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25]) +pmf_array = jnp.array([0.4, 0.3, 0.2, 0.1]) gen_int = DeterministicPMF(pmf_array, name="gen_int") # (2) Initial infections (inferred with a prior) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + DistributionalRV(dist=dist.LogNormal(2.5, 1), name="I0"), InitializeInfectionsZeroPad(pmf_array.size), t_unit=1, ) -# (3) The random process for Rt -rt_proc = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), -) + +# (3) The random walk on log Rt, with an inferred s.d. Here, we +# construct a custom RandomVariable. +class MyRt(RandomVariable): + + def validate(self): + pass + + def sample(self, n_steps: int, **kwargs) -> tuple: + sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025)) + + rt_rv = TransformedRandomVariable( + "Rt_rv", + base_rv=SimpleRandomWalkProcess( + name="log_rt", + step_rv=DistributionalRV( + dist.Normal(0, sd_rt), + "rw_step_rv", + reparam=LocScaleReparam(0), + ), + init_rv=DistributionalRV( + dist.Normal(jnp.log(1), jnp.log(1.2)), "init_log_Rt_rv" + ), + ), + transforms=t.ExpTransform(), + ) + return rt_rv.sample(n_steps=n_steps, **kwargs) + + +rt_proc = MyRt() # (4) Latent infection process (which will use 1 and 2) latent_infections = Infections() @@ -158,7 +187,7 @@ The following diagram summarizes how the modules interact via composition; notab flowchart TB genint["(1) gen_int\n(DetermnisticPMF)"] i0["(2) I0\n(InfectionInitializationProcess)"] - rt["(3) rt_proc\n(RtRandomWalkProcess)"] + rt["(3) rt_proc\n(MyRt, the custom RV defined above)"] inf["(4) 
latent_infections\n(Infections)"] obs["(5) observation_process\n(PoissonObservation)"] @@ -175,14 +204,13 @@ Using `numpyro`, we can simulate data using the `sample()` member function of `R ```{python} # | label: simulate -np.random.seed(223) -with npro.handlers.seed(rng_seed=np.random.randint(1, 60)): - sim_data = model1.sample(n_timepoints_to_simulate=30) +with numpyro.handlers.seed(rng_seed=53): + sim_data = model1.sample(n_timepoints_to_simulate=40) sim_data ``` -To understand what has been accomplished here, visualize an $R_t$ sample path (left panel) and infections over time (right panel): +To understand what has been accomplished here, visualize an $\mathcal{R}(t)$ sample path (left panel) and infections over time (right panel): ```{python} # | label: fig-basic @@ -220,7 +248,7 @@ model1.run( ) ``` -Now, let's investigate the output, particularly the posterior distribution of the $R_t$ estimates: +Now, let's investigate the output, particularly the posterior distribution of the $\mathcal{R}(t)$ estimates: ```{python} # | label: fig-output-rt @@ -243,19 +271,19 @@ import arviz as az idata = az.from_numpyro(model1.mcmc) ``` -and use the InferenceData to compute the model-fit diagnostics. Here, we show diagnostic summary for the first 10 effective reproduction number $R_t$. +and use the InferenceData to compute the model-fit diagnostics. Here, we show the diagnostic summary for the effective reproduction number $\mathcal{R}(t)$, excluding the initial NaN-padded entries. ```{python} # | label: diagnostics diagnostic_stats_summary = az.summary( - idata.posterior["Rt"], + idata.posterior["Rt"][::, ::, 4:], # ignore nan padding kind="diagnostics", ) -print(diagnostic_stats_summary[:10]) +print(diagnostic_stats_summary) ``` -Below we use `plot_trace` to inspect the trace of the first 10 $R_t$ estimates. +Below we use `plot_trace` to inspect the trace of the first 10 inferred $\mathcal{R}(t)$ values. 
```{python} # | label: fig-trace-Rt @@ -265,20 +293,20 @@ plt.rcParams["figure.constrained_layout.use"] = True az.plot_trace( idata.posterior, var_names=["Rt"], - coords={"Rt_dim_0": np.arange(10)}, + coords={"Rt_dim_0": np.arange(4, 14)}, compact=False, ) plt.show() ``` -We inspect the posterior distribution of $R_t$ by plotting the 90% and 50% highest density intervals: +We inspect the posterior distribution of $\mathcal{R}(t)$ by plotting the 90% and 50% highest density intervals: ```{python} # | label: fig-hdi-Rt # | fig-cap: High density interval for Effective Reproduction Number -x_data = idata.posterior["Rt_dim_0"] -y_data = idata.posterior["Rt"] +x_data = idata.posterior["Rt_dim_0"][4:] +y_data = idata.posterior["Rt"][::, ::, 4:] fig, axes = plt.subplots(figsize=(6, 5)) az.plot_hdi( @@ -300,12 +328,12 @@ az.plot_hdi( ) # Add mean of the posterior to the figure -mean_Rt = np.mean(idata.posterior["Rt"], axis=1) -axes.plot(x_data, mean_Rt[0], color="C0", label="Mean") +median_ts = y_data.median(dim=["chain", "draw"]) +axes.plot(x_data, median_ts, color="C0", label="Median") axes.legend() axes.set_title("Posterior Effective Reproduction Number", fontsize=10) axes.set_xlabel("Time", fontsize=10) -axes.set_ylabel("$R_t$", fontsize=10) +axes.set_ylabel("$\\mathcal{R}(t)$", fontsize=10) plt.show() ``` @@ -338,11 +366,10 @@ az.plot_hdi( ax=axes, ) -# Add mean of the posterior to the figure -mean_latent_infection = np.mean( - idata.posterior["all_latent_infections"], axis=1 -) -axes.plot(x_data, mean_latent_infection[0], color="C0", label="Mean") +# plot the posterior median +median_ts = y_data.median(dim=["chain", "draw"]) +axes.plot(x_data, median_ts, color="C0", label="Median") + axes.legend() axes.set_title("Posterior Latent Infections", fontsize=10) axes.set_xlabel("Time", fontsize=10) diff --git a/docs/source/tutorials/extending_pyrenew.qmd b/docs/source/tutorials/extending_pyrenew.qmd index 468bd1b9..65d51052 100644 --- 
a/docs/source/tutorials/extending_pyrenew.qmd +++ b/docs/source/tutorials/extending_pyrenew.qmd @@ -28,8 +28,8 @@ import numpyro.distributions as dist from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.latent import InfectionsWithFeedback from pyrenew.model import RtInfectionsRenewalModel -from pyrenew.process import RtRandomWalkProcess -from pyrenew.metaclass import DistributionalRV +from pyrenew.process import SimpleRandomWalkProcess +from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable from pyrenew.latent import ( InfectionInitializationProcess, InitializeInfectionsExponentialGrowth, @@ -60,10 +60,14 @@ latent_infections = InfectionsWithFeedback( infection_feedback_pmf=gen_int, ) -rt = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), +rt = TransformedRandomVariable( + "Rt_rv", + base_rv=SimpleRandomWalkProcess( + name="log_rt", + step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), + init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + ), + transforms=t.ExpTransform(), ) ``` diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index 83299f1d..7e89df33 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -7,16 +7,16 @@ engine: jupyter ```{python} # | label: numpyro setup # | echo: false -import numpyro as npro +import numpyro -npro.set_host_device_count(2) +numpyro.set_host_device_count(2) ``` This document illustrates how a hospital admissions-only model can be fitted using data from the Pyrenew package, particularly the wastewater dataset. The CFA wastewater team created this dataset, which contains simulated data. ## Model definition -In this section, we provide the formal definition of the model. 
The hospitalization model is a semi-mechanistic model that describes the number of observed hospital admissions as a function of a set of latent variables. Mainly, the observed number of hospital admissions is discretely distributed with location at the number of latent hospital admissions: +In this section, we provide the formal definition of the model. The hospital admissions model is a semi-mechanistic model that describes the number of observed hospital admissions as a function of a set of latent variables. Mainly, the observed number of hospital admissions is discretely distributed with location at the number of latent hospital admissions: $$ h(t) \sim \text{HospDist}\left(H(t)\right) @@ -33,9 +33,9 @@ H(t) & = p_\mathrm{hosp}(t) \sum_{\tau = 0}^{T_d} d(\tau) I(t-\tau) \\ \end{align*} $$ -Were $d(\tau)$ is the infection to hospitalization interval, $I(t)$ is the number of latent infections at time $t$, $p_\mathrm{hosp}(t)$ is the infection to hospitalization rate. +Where $d(\tau)$ is the infection to hospital admission interval, $I(t)$ is the number of latent infections at time $t$, and $p_\mathrm{hosp}(t)$ is the infection to admission rate. -The number of latent hospital admissions at time $t$ is a function of the number of latent infections at time $t$ and the infection to hospitalization rate. The latent infections are modeled as a renewal process: +The number of latent hospital admissions at time $t$ is a function of the number of latent infections at time $t$ and the infection to admission rate. The latent infections are modeled as a renewal process: $$ \begin{align*} @@ -104,11 +104,11 @@ plt.show() ## Building the model -First, we will extract two datasets we will use as deterministic quantities: the generation interval and the infection to hospitalization interval. +First, we will extract two datasets we will use as deterministic quantities: the generation interval and the infection to hospital admission interval. 
```{python} # | label: fig-data-extract -# | fig-cap: Generation interval and infection to hospitalization interval +# | fig-cap: Generation interval and infection to hospital admission interval gen_int = datasets.load_generation_interval() inf_hosp_int = datasets.load_infection_admission_interval() @@ -117,7 +117,7 @@ gen_int_array = gen_int["probability_mass"].to_numpy() gen_int = gen_int_array inf_hosp_int = inf_hosp_int["probability_mass"].to_numpy() -# Taking a pick at the first 5 elements of each +# Taking a peek at the first 5 elements of each gen_int[:5], inf_hosp_int[:5] # Visualizing both quantities side by side @@ -126,7 +126,7 @@ fig, axs = plt.subplots(1, 2) axs[0].plot(gen_int) axs[0].set_title("Generation interval") axs[1].plot(inf_hosp_int) -axs[1].set_title("Infection to hospitalization interval") +axs[1].set_title("Infection to hospital admission interval") plt.show() ``` @@ -153,7 +153,7 @@ latent_hosp = latent.HospitalAdmissions( ) ``` -The `inf_hosp_int` is a `DeterministicPMF` object that takes the infection to hospitalization interval as input. The `hosp_rate` is a `DistributionalRV` object that takes a numpyro distribution to represent the infection to hospitalization rate. The `HospitalAdmissions` class is a `RandomVariable` that takes two distributions as inputs: the infection to admission interval and the infection to hospitalization rate. Now, we can define the rest of the other components: +The `inf_hosp_int` is a `DeterministicPMF` object that takes the infection to hospital admission interval as input. The `hosp_rate` is a `DistributionalRV` object that takes a numpyro distribution to represent the infection to hospital admission rate. The `HospitalAdmissions` class is a `RandomVariable` that takes two distributions as inputs: the infection to admission interval and the infection to hospital admission rate. 
Now, we can define the rest of the other components: ```{python} # | label: initializing-rest-of-model @@ -163,6 +163,7 @@ from pyrenew.latent import ( InitializeInfectionsExponentialGrowth, ) + # Infection process latent_inf = latent.Infections() I0 = InfectionInitializationProcess( @@ -179,11 +180,34 @@ I0 = InfectionInitializationProcess( # Generation interval and Rt gen_int = deterministic.DeterministicPMF(gen_int, name="gen_int") -rtproc = process.RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=transformation.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), -) + + +class MyRt(metaclass.RandomVariable): + + def validate(self): + pass + + def sample(self, n_steps: int, **kwargs) -> tuple: + sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025)) + + rt_rv = metaclass.TransformedRandomVariable( + "Rt_rv", + base_rv=process.SimpleRandomWalkProcess( + name="log_rt", + step_rv=metaclass.DistributionalRV( + dist.Normal(0, sd_rt), "rw_step_rv" + ), + init_rv=metaclass.DistributionalRV( + dist.Normal(0, 0.2), "init_log_Rt_rv" + ), + ), + transforms=transformation.ExpTransform(), + ) + + return rt_rv.sample(n_steps=n_steps, **kwargs) + + +rtproc = MyRt() # The observation model @@ -226,7 +250,8 @@ import numpy as np timeframe = 120 -with npro.handlers.seed(rng_seed=223): + +with numpyro.handlers.seed(rng_seed=223): simulated_data = hosp_model.sample(n_timepoints_to_simulate=timeframe) ``` @@ -242,9 +267,8 @@ axs[0].plot(simulated_data.Rt) axs[0].set_ylabel("Simulated Rt") # Admissions plot -axs[1].plot(simulated_data.observed_hosp_admissions) +axs[1].plot(simulated_data.observed_hosp_admissions, "-o") axs[1].set_ylabel("Simulated Admissions") -axs[1].set_yscale("log") fig.suptitle("Basic renewal model") fig.supxlabel("Time") @@ -319,6 +343,7 @@ Below we plot 90% and 50% highest density intervals for latent hospital admissio x_data = idata.posterior["latent_hospital_admissions_dim_0"] y_data = 
idata.posterior["latent_hospital_admissions"] + fig, axes = plt.subplots(figsize=(6, 5)) az.plot_hdi( x_data, @@ -340,11 +365,10 @@ az.plot_hdi( ax=axes, ) -# Add mean of the posterior to the figure -mean_latent_hosp_admission = np.mean( - idata.posterior["latent_hospital_admissions"], axis=1 -) -axes.plot(x_data, mean_latent_hosp_admission[0], color="C0", label="Mean") +# Add the posterior median to the figure +median_ts = y_data.median(dim=["chain", "draw"]) + +axes.plot(x_data, median_ts, color="C0", label="Median") axes.legend() axes.set_title("Posterior Hospital Admissions", fontsize=10) axes.set_xlabel("Time", fontsize=10) @@ -390,6 +414,11 @@ az.plot_hdi( fill_kwargs={"alpha": 0.6}, ax=axes, ) + +# Add the posterior median to the figure +median_ts = y_data.median(dim=["chain", "draw"]) +axes.plot(x_data, median_ts, color="C0", label="Median") +axes.legend() ``` @@ -499,12 +528,14 @@ And now we plot the posterior predictive distributions with a `{python} n_foreca ```{python} # | label: fig-output-posterior-predictive-forecast # | fig-cap: Posterior predictive admissions, including a forecast. 
+x_data = ( + idata_weekday.posterior_predictive["negbinom_rv_dim_0"] + gen_int.size() +) +y_data = idata_weekday.posterior_predictive["negbinom_rv"] fig, axes = plt.subplots(figsize=(6, 5)) az.plot_hdi( - idata_weekday.posterior_predictive["negbinom_rv_dim_0"] + gen_int.size(), - hdi_data=compute_eti( - idata_weekday.posterior_predictive["negbinom_rv"], 0.9 - ), + x_data, + hdi_data=compute_eti(y_data, 0.9), color="C0", smooth=False, fill_kwargs={"alpha": 0.3}, @@ -512,26 +543,22 @@ az.plot_hdi( ) az.plot_hdi( - idata_weekday.posterior_predictive["negbinom_rv_dim_0"] + gen_int.size(), - hdi_data=compute_eti( - idata_weekday.posterior_predictive["negbinom_rv"], 0.5 - ), + x_data, + hdi_data=compute_eti(y_data, 0.5), color="C0", smooth=False, fill_kwargs={"alpha": 0.6}, ax=axes, ) -# Add mean of the posterior to the figure -mean_latent_infection = np.mean( - idata_weekday.posterior_predictive["negbinom_rv"], axis=1 -) +# Add median of the posterior to the figure +median_ts = y_data.median(dim=["chain", "draw"]) plt.plot( - idata_weekday.posterior_predictive["negbinom_rv_dim_0"] + gen_int.size(), - mean_latent_infection[0], + x_data, + median_ts, color="C0", - label="Mean", + label="Median", ) plt.scatter( idata_weekday.observed_data["negbinom_rv_dim_0"] + gen_int.size(), diff --git a/docs/source/tutorials/pyrenew_demo.qmd b/docs/source/tutorials/pyrenew_demo.qmd index d7444d70..598308c8 100644 --- a/docs/source/tutorials/pyrenew_demo.qmd +++ b/docs/source/tutorials/pyrenew_demo.qmd @@ -38,17 +38,22 @@ numpyro.set_host_device_count(2) ```{python} from pyrenew.process import SimpleRandomWalkProcess +from pyrenew.metaclass import DistributionalRV ``` -To understand the simple random walk process underlying the sampling within the renewal process model, we first examine a single random walk path. 
Using the `sample` method from an instance of the `SimpleRandomWalkProcess` class, we first create an instance of the `SimpleRandomWalkProcess` class with a normal distribution of mean = 0 and standard deviation = 0.0001 as its input. Next, the `with` statement sets the seed for the random number generator for the n_timepoints of the block that follows. Inside the `with` block, the `q_samp = q(n_timepoints=100)` generates the sample instance over a n_timepoints of 100 time units. Finally, this single random walk process is visualized using `matplot.pyplot` to plot the exponential of the sample instance. +To understand the simple random walk process underlying the sampling within the renewal process model, we first examine a single random walk path. To do so, we first create an instance of the `SimpleRandomWalkProcess` class whose step and initial-value random variables both follow a normal distribution with mean = 0 and standard deviation = 0.001. Next, the `with` statement sets the seed for the random number generator for the block that follows. Inside the `with` block, `q_samp = q(n_steps=100)` calls the instance's `sample` method to generate a sample path over an `n_steps` period of 100 time units. Finally, this single random walk path is visualized using `matplotlib.pyplot` to plot the exponential of the sampled values. 
```{python} # | label: fig-randwalk # | fig-cap: Random walk example -np.random.seed(3312) -q = SimpleRandomWalkProcess(dist.Normal(0, 0.001)) -with seed(rng_seed=np.random.randint(0, 1000)): - q_samp = q(n_timepoints=100) +q = SimpleRandomWalkProcess( + "example_random_walk", + step_rv=DistributionalRV(dist.Normal(0, 0.001), "step_rv"), + init_rv=DistributionalRV(dist.Normal(0, 0.001), "init_rv"), +) + +with seed(rng_seed=325): + q_samp = q(n_steps=100) plt.plot(np.exp(q_samp[0])) ``` @@ -60,7 +65,6 @@ from pyrenew.latent import ( Infections, HospitalAdmissions, ) -from pyrenew.metaclass import DistributionalRV ``` Additionally, import several classes from Pyrenew, including a Poisson observation process, determininstic PMF and variable classes, the Pyrenew hospitalization model, and a renewal model (Rt) random walk process: @@ -69,11 +73,11 @@ Additionally, import several classes from Pyrenew, including a Poisson observati from pyrenew.observation import PoissonObservation from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.model import HospitalAdmissionsModel -from pyrenew.process import RtRandomWalkProcess from pyrenew.latent import ( InfectionInitializationProcess, InitializeInfectionsZeroPad, ) +from pyrenew.metaclass import TransformedRandomVariable import pyrenew.transformation as t ``` @@ -87,9 +91,10 @@ To initialize the model, we first define initial conditions, including: 4) latent hospitalization process, modeled by first defining the time interval from infections to hospitalizations as a `DeterministicPMF` input with 18 possible outcomes and corresponding probabilities given by the values in the array. The `HospitalAdmissions` function then takes in this defined time interval, as well as defining the rate at which infections are admitted to the hospital due to infection, modeled as a log-normal distribution with mean = `jnp.log(0.05)` and standard deviation = 0.05. 
-5) hospitalization observation process, modeled with a Poisson distribution +5) hospitalization observation process, modeled with a Poisson distribution + +6) A process to represent $\mathcal{R}(t)$ as a random walk on the log scale, with an inferred initial value and a fixed Normal step-size distribution. -6) an Rt random walk process with default settings ```{python} # Initializing model components: @@ -129,12 +134,15 @@ latent_admissions = HospitalAdmissions( # 5) An observation process for the hospital admissions admissions_process = PoissonObservation("poisson_rv") -# 6) A random walk process (it could be deterministic using -# pyrenew.process.DeterministicProcess()) -Rt_process = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), +# 6) The random walk on log Rt +Rt_process = TransformedRandomVariable( + "Rt_rv", + base_rv=SimpleRandomWalkProcess( + name="log_rt", + step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), + init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + ), + transforms=t.ExpTransform(), ) ``` diff --git a/model/src/pyrenew/metaclass.py b/model/src/pyrenew/metaclass.py index 9f868f9a..cc7418c2 100644 --- a/model/src/pyrenew/metaclass.py +++ b/model/src/pyrenew/metaclass.py @@ -5,18 +5,20 @@ """ from abc import ABCMeta, abstractmethod -from typing import NamedTuple, get_type_hints +from typing import get_type_hints import jax import jax.numpy as jnp import jax.random as jr import matplotlib.pyplot as plt import numpy as np -import numpyro as npro +import numpyro import polars as pl from jax.typing import ArrayLike from numpyro.infer import MCMC, NUTS, Predictive +from numpyro.infer.reparam import Reparam from pyrenew.mcmcutils import plot_posterior, spread_draws +from pyrenew.transformation import Transform def _assert_sample_and_rtype( @@ -209,45 +211,33 @@ def __call__(self, **kwargs): return self.sample(**kwargs) 
-class DistributionalRVSample(NamedTuple): - """ - Named tuple for the sample method of DistributionalRV - - Attributes - ---------- - value : ArrayLike - Sampled value from the distribution. - """ - - value: ArrayLike | None = None - - def __repr__(self) -> str: - """ - Representation of the DistributionalRVSample - """ - return f"DistributionalRVSample(value={self.value})" - - class DistributionalRV(RandomVariable): """ - Wrapper class for random variables that sample from a single `numpyro.distributions.Distribution`. + Wrapper class for random variables that sample + from a single :class:`numpyro.distributions.Distribution`. """ def __init__( self, - dist: npro.distributions.Distribution, + dist: numpyro.distributions.Distribution, name: str, - ): + reparam: Reparam = None, + ) -> None: """ Default constructor for DistributionalRV. Parameters ---------- - dist : npro.distributions.Distribution + dist : numpyro.distributions.Distribution Distribution of the random variable. name : str Name of the random variable. + reparam : numpyro.infer.reparam.Reparam + If not None, reparameterize sampling + from the distribution according to the + given numpyro reparameterizer + Returns ------- None @@ -257,6 +247,10 @@ def __init__( self.dist = dist self.name = name + if reparam is not None: + self.reparam_dict = {self.name: reparam} + else: + self.reparam_dict = {} return None @@ -265,7 +259,7 @@ def validate(dist: any) -> None: """ Validation of the distribution to be implemented in subclasses. """ - if not isinstance(dist, npro.distributions.Distribution): + if not isinstance(dist, numpyro.distributions.Distribution): raise ValueError( "dist should be an instance of " f"numpyro.distributions.Distribution, got {dist}" @@ -277,31 +271,31 @@ def sample( self, obs: ArrayLike | None = None, **kwargs, - ) -> DistributionalRVSample: + ) -> tuple: """ Sample from the distribution. 
Parameters ---------- obs : ArrayLike, optional - Observations passed as the `obs` argument to `numpyro.sample()`. Default `None`. + Observations passed as the `obs` argument to + :fun:`numpyro.sample()`. Default `None`. **kwargs : dict, optional - Additional keyword arguments passed through to internal sample calls, - should there be any. + Additional keyword arguments passed through + to internal sample calls, should there be any. Returns ------- - DistributionalRVSample - """ - return DistributionalRVSample( - value=jnp.atleast_1d( - npro.sample( - name=self.name, - fn=self.dist, - obs=obs, - ) - ), - ) + tuple + Containing the sampled from the distribution. + """ + with numpyro.handlers.reparam(config=self.reparam_dict): + sample = numpyro.sample( + name=self.name, + fn=self.dist, + obs=obs, + ) + return (jnp.atleast_1d(sample),) class Model(metaclass=ABCMeta): @@ -416,9 +410,13 @@ def run( Parameters ---------- nuts_args : dict, optional - Dictionary of arguments passed to the NUTS. Defaults to None. + Dictionary of arguments passed to the + :class:`numpyro.infer.NUTS` kernel. + Defaults to None. mcmc_args : dict, optional - Dictionary of passed to the MCMC sampler. Defaults to None. + Dictionary of arguments passed to the + :class:`numpyro.infer.MCMC` constructor. + Defaults to None. Returns ------- @@ -447,14 +445,14 @@ def print_summary( exclude_deterministic: bool = True, ) -> None: """ - A wrapper of MCMC.print_summary + A wrapper of :meth:`numpyro.infer.MCMC.print_summary` Parameters ---------- prob : float, optional - The acceptance probability of print_summary. Defaults to 0.9 + The width of the credible interval to show. Default 0.9 exclude_deterministic : bool, optional - Whether to print deterministic variables in the summary. + Whether to print deterministic sites in the summary. Defaults to True. 
Returns @@ -510,16 +508,19 @@ def posterior_predictive( **kwargs, ) -> dict: """ - A wrapper for numpyro.infer.Predictive to generate posterior predictive samples. + A wrapper for :class:`numpyro.infer.Predictive` to generate + posterior predictive samples. Parameters ---------- rng_key : ArrayLike, optional Random key for the Predictive function call. Defaults to None. numpyro_predictive_args : dict, optional - Dictionary of arguments to be passed to the numpyro.inference.Predictive constructor. + Dictionary of arguments to be passed to the + :class:`numpyro.inference.Predictive` constructor. **kwargs - Additional named arguments passed to the `__call__()` method of numpyro.inference.Predictive + Additional named arguments passed to the + `__call__()` method of :class:`numpyro.infer.Predictive` Returns ------- @@ -580,3 +581,119 @@ def prior_predictive( ) return predictive(rng_key, **kwargs) + + +class TransformedRandomVariable(RandomVariable): + """ + Class to represent RandomVariables defined + by taking the output of another RV's + :meth:`RandomVariable.sample()` method + and transforming it by a given transformation + (typically a :class:`Transform`) + """ + + def __init__( + self, + name: str, + base_rv: RandomVariable, + transforms: Transform | tuple[Transform], + ): + """ + Default constructor + + Parameters + ---------- + + name : str + A name for the random variable instance + + base_rv : RandomVariable + The underlying (untransformed) RandomVariable + + transforms : Transform + Transformation or tuple of transformations + to apply to the output of + `base_rv.sample()`; single values will be coerced to + a length-one tuple. 
If a tuple, should be the same + length as the tuple returned by `base_rv.sample()` + + Returns + ------- + None + """ + self.name = name + self.base_rv = base_rv + + if not isinstance(transforms, tuple): + transforms = (transforms,) + self.transforms = transforms + self.validate() + + def sample(self, **kwargs) -> tuple: + """ + Sample method. Call self.base_rv.sample() + and then apply the transforms specified + in self.transforms. + + Parameters + ---------- + **kwargs : + Keyword arguments passed to self.base_rv.sample() + + Returns + ------- + tuple of the same length as the tuple returned by + self.base_rv.sample() + """ + + untransformed_values = self.base_rv.sample(**kwargs) + + return tuple( + t(uv) for t, uv in zip(self.transforms, untransformed_values) + ) + + def sample_length(self): + """ + Sample length for a transformed + random variable must be equal to the + length of self.transforms or + validation will fail. + + Returns + ------- + int + Equal to the length self.transforms + """ + return len(self.transforms) + + def validate(self): + """ + Perform validation checks on a + TransformedRandomVariable instance, + confirming that all transformations + are callable and that the number of + transformations is equal to the sample + length of the base random variable. + + Returns + ------- + None + on successful validation, or raise a ValueError + """ + for t in self.transforms: + if not callable(t): + raise ValueError( + "All entries in self.transforms " "must be callable" + ) + if hasattr(self.base_rv, "sample_length"): + n_transforms = len(self.transforms) + n_entries = self.base_rv.sample_length() + if not n_transforms == n_entries: + raise ValueError( + "There must be exactly as many transformations " + "specified as entries self.transforms as there are " + "entries in the tuple returned by " + "self.base_rv.sample()." 
+ f"Got {n_transforms} transforms and {n_entries} " + "entries" + ) diff --git a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py index 8a80fbc1..8db8bd6e 100644 --- a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py +++ b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py @@ -6,7 +6,7 @@ from typing import NamedTuple import jax.numpy as jnp -import numpyro as npro +import numpyro import pyrenew.arrayutils as au from numpy.typing import ArrayLike from pyrenew.deterministic import NullObservation @@ -164,7 +164,8 @@ def sample( Notes ----- - Either `data_observed_infections` or `n_timepoints_to_simulate` must be specified, not both. + Either `data_observed_infections` or `n_timepoints_to_simulate` + must be specified, not both. Returns ------- @@ -193,7 +194,7 @@ def sample( # Sampling from Rt (possibly with a given Rt, depending on # the Rt_process (RandomVariable) object.) Rt, *_ = self.Rt_process_rv( - n_timepoints=n_timepoints, + n_steps=n_timepoints, **kwargs, ) @@ -222,7 +223,7 @@ def sample( all_latent_infections = jnp.hstack( [I0, post_initialization_latent_infections] ) - npro.deterministic("all_latent_infections", all_latent_infections) + numpyro.deterministic("all_latent_infections", all_latent_infections) if observed_infections is not None: observed_infections = au.pad_x_to_match_y( @@ -238,6 +239,7 @@ def sample( jnp.nan, pad_direction="start", ) + numpyro.deterministic("Rt", Rt) return RtInfectionsRenewalSample( Rt=Rt, diff --git a/model/src/pyrenew/process/__init__.py b/model/src/pyrenew/process/__init__.py index 1613f168..bad08343 100644 --- a/model/src/pyrenew/process/__init__.py +++ b/model/src/pyrenew/process/__init__.py @@ -9,13 +9,11 @@ RtPeriodicDiffProcess, RtWeeklyDiffProcess, ) -from pyrenew.process.rtrandomwalk import RtRandomWalkProcess from pyrenew.process.simplerandomwalk import SimpleRandomWalkProcess __all__ = [ "ARProcess", "FirstDifferenceARProcess", - 
"RtRandomWalkProcess", "SimpleRandomWalkProcess", "RtPeriodicDiffProcess", "RtWeeklyDiffProcess", diff --git a/model/src/pyrenew/process/rtrandomwalk.py b/model/src/pyrenew/process/rtrandomwalk.py deleted file mode 100644 index 5722fd54..00000000 --- a/model/src/pyrenew/process/rtrandomwalk.py +++ /dev/null @@ -1,125 +0,0 @@ -# -*- coding: utf-8 -*- -# numpydoc ignore=GL08 - -import numpyro as npro -import numpyro.distributions as dist -import pyrenew.transformation as t -from pyrenew.metaclass import RandomVariable -from pyrenew.process.simplerandomwalk import SimpleRandomWalkProcess - - -class RtRandomWalkProcess(RandomVariable): - r"""Rt Randomwalk Process - - Notes - ----- - - The process is defined as follows: - - .. math:: - - Rt(0) &\sim \text{Rt0_dist} \\ - Rt(t) &\sim \text{Rt_transform}(\text{Rt_transformed_rw}(t)) - """ - - def __init__( - self, - Rt0_dist: dist.Distribution, - Rt_rw_dist: dist.Distribution, - Rt_transform: t.Transform | None = None, - ) -> None: - """ - Default constructor - - Parameters - ---------- - Rt0_dist : dist.Distribution - Initial distribution of Rt. - Rt_rw_dist : dist.Distribution - Randomwalk process. - Rt_transform : numpyro.distributions.transformers.Transform, optional - Transformation applied to the sampled Rt0. If None, the identity - transformation is used. - - Returns - ------- - None - """ - if Rt_transform is None: - Rt_transform = t.IdentityTransform() - - RtRandomWalkProcess.validate(Rt0_dist, Rt_transform, Rt_rw_dist) - - self.Rt0_dist = Rt0_dist - self.Rt_transform = Rt_transform - self.Rt_rw_dist = Rt_rw_dist - - return None - - @staticmethod - def validate( - Rt0_dist: dist.Distribution, - Rt_transform: t.Transform, - Rt_rw_dist: dist.Distribution, - ) -> None: - """ - Validates Rt0_dist, Rt_transform, and Rt_rw_dist. 
- - Parameters - ---------- - Rt0_dist : dist.Distribution, optional - Initial distribution of Rt, expected dist.Distribution - Rt_transform : numpyro.distributions.transforms.Transform - Transformation applied to the sampled Rt0. - Rt_rw_dist : any - Randomwalk process, expected dist.Distribution. - - Returns - ------- - None - - Raises - ------ - AssertionError - If Rt0_dist or Rt_rw_dist are not dist.Distribution or if - Rt_transform is not numpyro.distributions.transforms.Transform. - """ - assert isinstance(Rt0_dist, dist.Distribution) - assert isinstance(Rt_transform, t.Transform) - assert isinstance(Rt_rw_dist, dist.Distribution) - - def sample( - self, - n_timepoints: int, - **kwargs, - ) -> tuple: - """ - Generate samples from the process - - Parameters - ---------- - n_timepoints : int - Number of timepoints to sample. - **kwargs : dict, optional - Additional keyword arguments passed through to internal sample() - calls, should there be any. - - Returns - ------- - tuple - With a single array of shape (n_timepoints,). 
- """ - - Rt0 = npro.sample("Rt0", self.Rt0_dist) - - Rt0_trans = self.Rt_transform(Rt0) - Rt_trans_proc = SimpleRandomWalkProcess(self.Rt_rw_dist) - Rt_trans_ts, *_ = Rt_trans_proc( - n_timepoints=n_timepoints, - name="Rt_transformed_rw", - init=Rt0_trans, - ) - - Rt = npro.deterministic("Rt", self.Rt_transform.inv(Rt_trans_ts)) - - return (Rt,) diff --git a/model/src/pyrenew/process/simplerandomwalk.py b/model/src/pyrenew/process/simplerandomwalk.py index c18bec67..cc396192 100644 --- a/model/src/pyrenew/process/simplerandomwalk.py +++ b/model/src/pyrenew/process/simplerandomwalk.py @@ -2,8 +2,6 @@ # numpydoc ignore=GL08 import jax.numpy as jnp -import numpyro as npro -import numpyro.distributions as dist from numpyro.contrib.control_flow import scan from pyrenew.metaclass import RandomVariable @@ -12,45 +10,58 @@ class SimpleRandomWalkProcess(RandomVariable): """ Class for a Markovian random walk with an a - arbitrary step distribution + step distribution """ def __init__( self, - error_distribution: dist.Distribution, + name: str, + step_rv: RandomVariable, + init_rv: RandomVariable, + t_start: int = None, + t_unit: int = None, ) -> None: """ Default constructor Parameters ---------- - error_distribution : dist.Distribution - Passed to numpyro.sample. + name : str + A name for the random variable, used to + name sites within it in :func:`numpyro.sample()` + calls. + step_rv : RandomVariable + RandomVariable representing the step distribution.
+ init_rv : RandomVariable + RandomVariable representing the initial value of + the process + t_start : int + See :class:`RandomVariable` + t_unit : int + See :class:`RandomVariable` Returns ------- None """ - self.error_distribution = error_distribution + self.name = name + self.step_rv = step_rv + self.init_rv = init_rv + self.t_start = t_start + self.t_unit = t_unit def sample( self, - n_timepoints: int, - name: str = "randomwalk", - init: float = None, + n_steps: int, **kwargs, ) -> tuple: """ - Samples from the randomwalk + Sample from the random walk. Parameters ---------- - n_timepoints : int - Length of the walk. - name : str, optional - Passed to numpyro.sample, by default "randomwalk" - init : float, optional - Initial point of the walk, by default None + n_steps : int + Length of the walk to sample. **kwargs : dict, optional Additional keyword arguments passed through to internal sample() calls, should there be any. @@ -58,29 +69,29 @@ def sample( Returns ------- tuple - With a single array of shape (n_timepoints,). + With a single array of shape (n_steps,). """ - if init is None: - init = npro.sample(name + "_init", self.error_distribution) + init, *_ = self.init_rv(**kwargs) def transition(x_prev, _): # numpydoc ignore=GL08 - diff = npro.sample(name + "_diffs", self.error_distribution) + diff, *_ = self.step_rv(**kwargs) x_curr = x_prev + diff return x_curr, x_curr _, x = scan( transition, init=init, - xs=jnp.arange(n_timepoints - 1), + xs=jnp.arange(n_steps - 1), ) - return (jnp.hstack([init, x]),) + return (jnp.hstack([init, x.flatten()]),) @staticmethod def validate(): """ - Validates inputted parameters, implementation pending. + Validates input parameters, implementation pending. 
""" + super().validate() return None diff --git a/model/src/test/test_forecast.py b/model/src/test/test_forecast.py index 523297b3..02255544 100644 --- a/model/src/test/test_forecast.py +++ b/model/src/test/test_forecast.py @@ -13,10 +13,10 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV +from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation -from pyrenew.process import RtRandomWalkProcess +from pyrenew.process import SimpleRandomWalkProcess def test_forecast(): @@ -31,11 +31,16 @@ def test_forecast(): ) latent_infections = Infections() observed_infections = PoissonObservation("poisson_rv") - rt = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), + rt = TransformedRandomVariable( + "Rt_rv", + base_rv=SimpleRandomWalkProcess( + name="log_rt", + step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), + init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + ), + transforms=t.ExpTransform(), ) + model = RtInfectionsRenewalModel( I0_rv=I0, gen_int_rv=gen_int, diff --git a/model/src/test/test_latent_admissions.py b/model/src/test/test_latent_admissions.py index 6d5abea1..9033e9bf 100644 --- a/model/src/test/test_latent_admissions.py +++ b/model/src/test/test_latent_admissions.py @@ -9,8 +9,8 @@ from pyrenew import transformation as t from pyrenew.deterministic import DeterministicPMF from pyrenew.latent import HospitalAdmissions, Infections -from pyrenew.metaclass import DistributionalRV -from pyrenew.process import RtRandomWalkProcess +from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable +from pyrenew.process import SimpleRandomWalkProcess def test_admissions_sample(): @@ -22,13 +22,18 @@ def test_admissions_sample(): # Generating Rt and Infections to compute the 
hospital admissions np.random.seed(223) - rt = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), + rt = TransformedRandomVariable( + "Rt_rv", + base_rv=SimpleRandomWalkProcess( + name="log_rt", + step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), + init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + ), + transforms=t.ExpTransform(), ) + with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - sim_rt, *_ = rt(n_timepoints=30) + sim_rt, *_ = rt(n_steps=30) gen_int = jnp.array([0.5, 0.1, 0.1, 0.2, 0.1]) i0 = 10 * jnp.ones_like(gen_int) diff --git a/model/src/test/test_latent_infections.py b/model/src/test/test_latent_infections.py index 0c9be7d4..f330464f 100755 --- a/model/src/test/test_latent_infections.py +++ b/model/src/test/test_latent_infections.py @@ -9,7 +9,8 @@ import pyrenew.transformation as t import pytest from pyrenew.latent import Infections -from pyrenew.process import RtRandomWalkProcess +from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable +from pyrenew.process import SimpleRandomWalkProcess def test_infections_as_deterministic(): @@ -19,13 +20,18 @@ def test_infections_as_deterministic(): """ np.random.seed(223) - rt = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), + rt = TransformedRandomVariable( + "Rt_rv", + base_rv=SimpleRandomWalkProcess( + name="log_rt", + step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), + init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + ), + transforms=t.ExpTransform(), ) + with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - sim_rt, *_ = rt(n_timepoints=30) + sim_rt, *_ = rt(n_steps=30) gen_int = jnp.array([0.25, 0.25, 0.25, 0.25]) diff --git a/model/src/test/test_model_basic_renewal.py 
b/model/src/test/test_model_basic_renewal.py index d0961eb3..ba66cda3 100644 --- a/model/src/test/test_model_basic_renewal.py +++ b/model/src/test/test_model_basic_renewal.py @@ -16,15 +16,39 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV +from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation -from pyrenew.process import RtRandomWalkProcess +from pyrenew.process import SimpleRandomWalkProcess + + +def get_default_rt(): + """ + Helper function to create a default Rt + RandomVariable for this testing session. + + Returns + ------- + TransformedRandomVariable : + A log-scale random walk with fixed + init value and step size priors + """ + return TransformedRandomVariable( + "Rt_rv", + base_rv=SimpleRandomWalkProcess( + name="log_rt", + step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), + init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + ), + transforms=t.ExpTransform(), + ) def test_model_basicrenewal_no_timepoints_or_observations(): """ - Test that the basic renewal model does not run without either n_timepoints_to_simulate or observed_admissions + Test that the basic renewal model does not run + without either n_timepoints_to_simulate or + observed_admissions """ gen_int = DeterministicPMF( @@ -37,11 +61,7 @@ def test_model_basicrenewal_no_timepoints_or_observations(): observed_infections = PoissonObservation("poisson_rv") - rt = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), - ) + rt = get_default_rt() model1 = RtInfectionsRenewalModel( I0_rv=I0, @@ -74,11 +94,7 @@ def test_model_basicrenewal_both_timepoints_and_observations(): observed_infections = PoissonObservation("possion_rv") - rt = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, 
low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), - ) + rt = get_default_rt() model1 = RtInfectionsRenewalModel( I0_rv=I0, @@ -119,11 +135,7 @@ def test_model_basicrenewal_no_obs_model(): latent_infections = Infections() - rt = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), - ) + rt = get_default_rt() model0 = RtInfectionsRenewalModel( gen_int_rv=gen_int, @@ -197,11 +209,7 @@ def test_model_basicrenewal_with_obs_model(): observed_infections = PoissonObservation("poisson_rv") - rt = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), - ) + rt = get_default_rt() model1 = RtInfectionsRenewalModel( I0_rv=I0, @@ -251,11 +259,7 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08 observed_infections = PoissonObservation("poisson_rv") - rt = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), - ) + rt = get_default_rt() model1 = RtInfectionsRenewalModel( I0_rv=I0, diff --git a/model/src/test/test_model_hospitalizations.py b/model/src/test/test_model_hosp_admissions.py similarity index 87% rename from model/src/test/test_model_hospitalizations.py rename to model/src/test/test_model_hosp_admissions.py index 7eb6a8c4..1abb2a3b 100644 --- a/model/src/test/test_model_hospitalizations.py +++ b/model/src/test/test_model_hosp_admissions.py @@ -5,7 +5,7 @@ import jax.numpy as jnp import jax.random as jr import numpy as np -import numpyro as npro +import numpyro import numpyro.distributions as dist import polars as pl import pytest @@ -21,10 +21,36 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV, RandomVariable +from pyrenew.metaclass import ( + DistributionalRV, + RandomVariable, 
+ TransformedRandomVariable, +) from pyrenew.model import HospitalAdmissionsModel from pyrenew.observation import PoissonObservation -from pyrenew.process import RtRandomWalkProcess +from pyrenew.process import SimpleRandomWalkProcess + + +def get_default_rt(): + """ + Helper function to create a default Rt + RandomVariable for this testing session. + + Returns + ------- + TransformedRandomVariable : + A log-scale random walk with fixed + init value and step size priors + """ + return TransformedRandomVariable( + "Rt_rv", + base_rv=SimpleRandomWalkProcess( + name="log_rt", + step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), + init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + ), + transforms=t.ExpTransform(), + ) class UniformProbForTest(RandomVariable): # numpydoc ignore=GL08 @@ -39,13 +65,16 @@ def validate(self): # numpydoc ignore=GL08 def sample(self, **kwargs): # numpydoc ignore=GL08 return ( - npro.sample(name=self.name, fn=dist.Uniform(high=0.99, low=0.01)), + numpyro.sample( + name=self.name, fn=dist.Uniform(high=0.99, low=0.01) + ), ) def test_model_hosp_no_timepoints_or_observations(): """ - Checks that the Hospitalization model does not run without either n_timepoints_to_simulate or observed_admissions + Checks that the hospital admissions model does not run + without either n_timepoints_to_simulate or observed_admissions """ gen_int = DeterministicPMF( @@ -55,11 +84,8 @@ def test_model_hosp_no_timepoints_or_observations(): I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") latent_infections = Infections() - Rt_process = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), - ) + Rt_process = get_default_rt() + observed_admissions = PoissonObservation("poisson_rv") inf_hosp = DeterministicPMF( @@ -104,8 +130,7 @@ def test_model_hosp_no_timepoints_or_observations(): hosp_admission_obs_process_rv=observed_admissions, ) - 
np.random.seed(223) - with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): + with numpyro.handlers.seed(rng_seed=233): with pytest.raises(ValueError, match="Either"): model1.sample( n_timepoints_to_simulate=None, data_observed_admissions=None @@ -114,7 +139,8 @@ def test_model_hosp_no_timepoints_or_observations(): def test_model_hosp_both_timepoints_and_observations(): """ - Checks that the Hospitalization model does not run with both n_timepoints_to_simulate and observed_admissions passed + Checks that the hospital admissions model does not run with + both n_timepoints_to_simulate and observed_admissions passed """ gen_int = DeterministicPMF( @@ -124,11 +150,8 @@ def test_model_hosp_both_timepoints_and_observations(): I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") latent_infections = Infections() - Rt_process = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), - ) + Rt_process = get_default_rt() + observed_admissions = PoissonObservation("poisson_rv") inf_hosp = DeterministicPMF( @@ -174,7 +197,7 @@ def test_model_hosp_both_timepoints_and_observations(): ) np.random.seed(223) - with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): + with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)): with pytest.raises(ValueError, match="Cannot pass both"): model1.sample( n_timepoints_to_simulate=30, @@ -200,11 +223,8 @@ def test_model_hosp_no_obs_model(): ) latent_infections = Infections() - Rt_process = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), - ) + Rt_process = get_default_rt() + inf_hosp = DeterministicPMF( jnp.array( [ @@ -250,13 +270,13 @@ def test_model_hosp_no_obs_model(): # Sampling and fitting model 0 (with no obs for infections) np.random.seed(223) - with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): + with 
numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)): model0_samp = model0.sample(n_timepoints_to_simulate=30) model0.hosp_admission_obs_process_rv = NullObservation() np.random.seed(223) - with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): + with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)): model1_samp = model0.sample(n_timepoints_to_simulate=30) np.testing.assert_array_almost_equal(model0_samp.Rt, model1_samp.Rt) @@ -310,11 +330,7 @@ def test_model_hosp_with_obs_model(): ) latent_infections = Infections() - Rt_process = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), - ) + Rt_process = get_default_rt() observed_admissions = PoissonObservation("poisson_rv") inf_hosp = DeterministicPMF( @@ -361,7 +377,7 @@ def test_model_hosp_with_obs_model(): # Sampling and fitting model 0 (with no obs for infections) np.random.seed(223) - with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): + with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)): model1_samp = model1.sample(n_timepoints_to_simulate=30) model1.run( @@ -400,11 +416,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): ) latent_infections = Infections() - Rt_process = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), - ) + Rt_process = get_default_rt() observed_admissions = PoissonObservation("poisson_rv") inf_hosp = DeterministicPMF( @@ -462,7 +474,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): # Sampling and fitting model 0 (with no obs for infections) np.random.seed(223) - with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): + with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)): model1_samp = model1.sample(n_timepoints_to_simulate=30) model1.run( @@ -503,11 +515,8 @@ def test_model_hosp_with_obs_model_weekday_phosp(): ) 
latent_infections = Infections() - Rt_process = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), - ) + Rt_process = get_default_rt() + observed_admissions = PoissonObservation("poisson_rv") inf_hosp = DeterministicPMF( @@ -575,7 +584,7 @@ def test_model_hosp_with_obs_model_weekday_phosp(): # Sampling and fitting model 0 (with no obs for infections) np.random.seed(223) - with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): + with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)): model1_samp = model1.sample( n_timepoints_to_simulate=n_obs_to_generate, padding=pad_size ) diff --git a/model/src/test/test_predictive.py b/model/src/test/test_predictive.py index 1089974e..d98269a1 100644 --- a/model/src/test/test_predictive.py +++ b/model/src/test/test_predictive.py @@ -14,12 +14,12 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV +from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation -from pyrenew.process import RtRandomWalkProcess +from pyrenew.process import SimpleRandomWalkProcess -pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25]) +pmf_array = jnp.array([0.25, 0.1, 0.2, 0.45]) gen_int = DeterministicPMF(pmf_array, name="gen_int") I0 = InfectionInitializationProcess( "I0_initialization", @@ -29,11 +29,16 @@ ) latent_infections = Infections() observed_infections = PoissonObservation("poisson_rv") -rt = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), +rt = TransformedRandomVariable( + "Rt_rv", + base_rv=SimpleRandomWalkProcess( + name="log_rt", + step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), + init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + ), + 
transforms=t.ExpTransform(), ) + model = RtInfectionsRenewalModel( I0_rv=I0, gen_int_rv=gen_int, diff --git a/model/src/test/test_random_key.py b/model/src/test/test_random_key.py index 5f1c9986..6a44e4ff 100644 --- a/model/src/test/test_random_key.py +++ b/model/src/test/test_random_key.py @@ -8,7 +8,7 @@ import jax.numpy as jnp import jax.random as jr import numpy as np -import numpyro as npro +import numpyro import numpyro.distributions as dist import pyrenew.transformation as t from numpy.testing import assert_array_equal, assert_raises @@ -18,10 +18,10 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV +from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation -from pyrenew.process import RtRandomWalkProcess +from pyrenew.process import SimpleRandomWalkProcess def create_test_model(): # numpydoc ignore=GL08 @@ -35,10 +35,14 @@ def create_test_model(): # numpydoc ignore=GL08 ) latent_infections = Infections() observed_infections = PoissonObservation("poisson_rv") - rt = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), + rt = TransformedRandomVariable( + "Rt_rv", + base_rv=SimpleRandomWalkProcess( + name="log_rt", + step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), + init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + ), + transforms=t.ExpTransform(), ) model = RtInfectionsRenewalModel( I0_rv=I0, @@ -99,7 +103,7 @@ def test_rng_keys_produce_correct_samples(): ] # sample only a single model and use that model's samples # as the observed_infections for the rest of the models - with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): + with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)): model_sample = models[0].sample( 
n_timepoints_to_simulate=n_timepoints_to_simulate[0] ) diff --git a/model/src/test/test_random_walk.py b/model/src/test/test_random_walk.py index c2dcb186..66be96db 100755 --- a/model/src/test/test_random_walk.py +++ b/model/src/test/test_random_walk.py @@ -4,6 +4,8 @@ import numpyro import numpyro.distributions as dist from numpy.testing import assert_almost_equal +from pyrenew.deterministic import DeterministicVariable +from pyrenew.metaclass import DistributionalRV from pyrenew.process import SimpleRandomWalkProcess @@ -12,16 +14,32 @@ def test_rw_can_be_sampled(): Check that a simple random walk can be initialized and sampled from """ - rw_normal = SimpleRandomWalkProcess(dist.Normal(0, 1)) + init_rv_rand = DistributionalRV(dist.Normal(1, 0.5), "init_rv_rand") + init_rv_fixed = DeterministicVariable(50.0, "init_rv_fixed") + + step_rv = DistributionalRV(dist.Normal(0, 1), "rw_step") + + rw_init_rand = SimpleRandomWalkProcess( + "rw_rand_init", step_rv=step_rv, init_rv=init_rv_rand + ) + + rw_init_fixed = SimpleRandomWalkProcess( + "rw_fixed_init", step_rv=step_rv, init_rv=init_rv_fixed + ) with numpyro.handlers.seed(rng_seed=62): - # can sample with and without inits - ans0 = rw_normal(n_timepoints=3532, init=50.0) - ans1 = rw_normal(n_timepoints=5023) + # can sample with a fixed init + # and with a random init + ans_rand = rw_init_rand(n_steps=3532) + ans_fixed = rw_init_fixed(n_steps=5023) - # check that the samples are of the right shape - assert ans0[0].shape == (3532,) - assert ans1[0].shape == (5023,) + # check that the samples are of the right shape + assert ans_rand[0].shape == (3532,) + assert ans_fixed[0].shape == (5023,) + + # check that fixing inits works + assert_almost_equal(ans_fixed[0][0], init_rv_fixed.vars) + assert ans_rand[0][0] != init_rv_fixed.vars def test_rw_samples_correctly_distributed(): @@ -34,10 +52,18 @@ def test_rw_samples_correctly_distributed(): for step_mean, step_sd in zip( [0, 2.253, -3.2521, 1052, 1e-6], [1, 0.025, 3, 1, 
0.02] ): - rw_normal = SimpleRandomWalkProcess(dist.Normal(step_mean, step_sd)) - rw_init = 532.0 + rw_init_val = 532.0 + rw_normal = SimpleRandomWalkProcess( + name="rw_normal_test", + step_rv=DistributionalRV( + dist=dist.Normal(loc=step_mean, scale=step_sd), + name="rw_normal_dist", + ), + init_rv=DeterministicVariable(rw_init_val, "init_rv_fixed"), + ) + with numpyro.handlers.seed(rng_seed=62): - samples, *_ = rw_normal(n_timepoints=n_samples, init=rw_init) + samples, *_ = rw_normal(n_steps=n_samples) # Checking the shape assert samples.shape == (n_samples,) @@ -60,4 +86,4 @@ def test_rw_samples_correctly_distributed(): assert jnp.abs(jnp.log(jnp.std(diffs) / step_sd)) < jnp.log(1.1) # first value should be the init value - assert_almost_equal(samples[0], rw_init) + assert_almost_equal(samples[0], rw_init_val) diff --git a/model/src/test/test_transformed_rv_class.py b/model/src/test/test_transformed_rv_class.py new file mode 100644 index 00000000..cf52b487 --- /dev/null +++ b/model/src/test/test_transformed_rv_class.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- + +""" +Tests for TransformedRandomVariable class +""" + +import numpyro +import numpyro.distributions as dist +import pyrenew.transformation as t +import pytest +from numpy.testing import assert_almost_equal +from pyrenew.metaclass import ( + DistributionalRV, + RandomVariable, + TransformedRandomVariable, +) + + +class LengthTwoRV(RandomVariable): + """ + Class for a RandomVariable + with sample_length 2 + and values 1 and 5 + """ + + def sample(self, **kwargs): + """ + Deterministic sampling method + that returns a length-2 tuple + + Returns + ------- + tuple + (1, 5) + """ + return (1, 5) + + def sample_length(self): + """ + Report the sample length as 2 + + Returns + ------- + int + 2 + """ + return 2 + + def validate(self): + """ + No validation. 
+ + Returns + ------- + None + """ + return None + + +def test_transform_rv_validation(): + """ + Test that a TransformedRandomVariable validation + works as expected. + """ + + base_rv = DistributionalRV(dist.Normal(0, 1), "test_normal") + base_rv.sample_length = lambda: 1 # numpydoc ignore=GL08 + + l2_rv = LengthTwoRV() + + test_transforms = [t.IdentityTransform(), t.ExpTransform()] + + for tr in test_transforms: + my_rv = TransformedRandomVariable("test_transformed_rv", base_rv, tr) + assert isinstance(my_rv.transforms, tuple) + assert len(my_rv.transforms) == 1 + assert my_rv.sample_length() == 1 + not_callable_err = "All entries in self.transforms " "must be callable" + sample_length_err = "There must be exactly as many transformations" + with pytest.raises(ValueError, match=sample_length_err): + _ = TransformedRandomVariable( + "should_error_due_to_too_many_transforms", base_rv, (tr, tr) + ) + with pytest.raises(ValueError, match=sample_length_err): + _ = TransformedRandomVariable( + "should_error_due_to_too_few_transforms", l2_rv, tr + ) + with pytest.raises(ValueError, match=sample_length_err): + _ = TransformedRandomVariable( + "should_also_error_due_to_too_few_transforms", l2_rv, (tr,) + ) + with pytest.raises(ValueError, match=not_callable_err): + _ = TransformedRandomVariable( + "should_error_due_to_not_callable", l2_rv, (1,) + ) + with pytest.raises(ValueError, match=not_callable_err): + _ = TransformedRandomVariable( + "should_error_due_to_not_callable", base_rv, (1,) + ) + + +def test_transforms_applied_at_sampling(): + """ + Test that TransformedRandomVariable + instances correctly apply their specified + transformations at sampling + """ + norm_rv = DistributionalRV(dist.Normal(0, 1), "test_normal") + norm_rv.sample_length = lambda: 1 + + l2_rv = LengthTwoRV() + + for tr in [ + t.IdentityTransform(), + t.ExpTransform(), + t.ExpTransform().inv, + t.ScaledLogitTransform(5), + ]: + tr_norm = TransformedRandomVariable("transformed_normal", norm_rv, tr) 
+ + tr_l2 = TransformedRandomVariable( + "transformed_length_2", l2_rv, (tr, t.ExpTransform()) + ) + + with numpyro.handlers.seed(rng_seed=5): + norm_base_sample = norm_rv.sample() + l2_base_sample = l2_rv.sample() + with numpyro.handlers.seed(rng_seed=5): + norm_transformed_sample = tr_norm.sample() + l2_transformed_sample = tr_l2.sample() + + assert_almost_equal( + (tr(norm_base_sample[0]),), norm_transformed_sample + ) + assert_almost_equal( + (tr(l2_base_sample[0]), t.ExpTransform()(l2_base_sample[1])), + l2_transformed_sample, + )