Skip to content

Commit

Permalink
create randomvariable module (#412)
Browse files Browse the repository at this point in the history
* testing convolve mode

* update tutorial to work with convolve mode valid

* update latent admissions test

* update DOW tutorial for convolve mode valid

* update hosp model tests

* create helper function for convolve and add tests

* forgot to run precommit earlier

* update test for model with DOW effect

* renaming helper function, add n_initialization_point

* create randomvariable module

* make suffixes across variables unifrom

* modify import statements

* missed few imports

* pre-commit changes

* update metaclass.py

* add randomvariable.rst
  • Loading branch information
sbidari authored Aug 26, 2024
1 parent 1a86104 commit 79706e9
Show file tree
Hide file tree
Showing 28 changed files with 663 additions and 629 deletions.
1 change: 1 addition & 0 deletions docs/source/msei_reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Reference
model
latent
process
randomvariable
observation
datasets
msei
Expand Down
7 changes: 7 additions & 0 deletions docs/source/msei_reference/randomvariable.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Random Variables
===========

.. automodule:: pyrenew.randomvariable
:members:
:undoc-members:
:show-inheritance:
17 changes: 7 additions & 10 deletions docs/source/tutorials/basic_renewal_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,8 @@ from pyrenew.latent import (
from pyrenew.observation import PoissonObservation
from pyrenew.deterministic import DeterministicPMF
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.metaclass import (
RandomVariable,
DistributionalRV,
TransformedRandomVariable,
)
from pyrenew.metaclass import RandomVariable
from pyrenew.randomvariable import DistributionalVariable, TransformedVariable
import pyrenew.transformation as t
from numpyro.infer.reparam import LocScaleReparam
```
Expand Down Expand Up @@ -64,7 +61,7 @@ flowchart LR
subgraph latent[Latent module]
inf["latent_infections_rv\n(Infections)"]
i0["I0_rv\n(DistributionalRV)"]
i0["I0_rv\n(DistributionalVariable)"]
end
subgraph process[Process module]
Expand Down Expand Up @@ -126,7 +123,7 @@ gen_int = DeterministicPMF(name="gen_int", value=pmf_array)
# (2) Initial infections (inferred with a prior)
I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalRV(name="I0", distribution=dist.LogNormal(2.5, 1)),
DistributionalVariable(name="I0", distribution=dist.LogNormal(2.5, 1)),
InitializeInfectionsZeroPad(pmf_array.size),
t_unit=1,
)
Expand All @@ -142,17 +139,17 @@ class MyRt(RandomVariable):
def sample(self, n: int, **kwargs) -> tuple:
sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025))
rt_rv = TransformedRandomVariable(
rt_rv = TransformedVariable(
name="log_rt_random_walk",
base_rv=RandomWalk(
name="log_rt",
step_rv=DistributionalRV(
step_rv=DistributionalVariable(
name="rw_step_rv", distribution=dist.Normal(0, 0.025)
),
),
transforms=t.ExpTransform(),
)
rt_init_rv = DistributionalRV(
rt_init_rv = DistributionalVariable(
name="init_log_rt", distribution=dist.Normal(0, 0.2)
)
init_rt, *_ = rt_init_rv.sample()
Expand Down
24 changes: 12 additions & 12 deletions docs/source/tutorials/day_of_the_week.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ inf_hosp_int_array = inf_hosp_int["probability_mass"].to_numpy()
```{python}
# | label: latent-hosp
# | code-fold: true
from pyrenew import latent, deterministic, metaclass
from pyrenew import latent, deterministic, randomvariable
import jax.numpy as jnp
import numpyro.distributions as dist
inf_hosp_int = deterministic.DeterministicPMF(
name="inf_hosp_int", value=inf_hosp_int_array
)
hosp_rate = metaclass.DistributionalRV(
hosp_rate = randomvariable.DistributionalVariable(
name="IHR", distribution=dist.LogNormal(jnp.log(0.05), jnp.log(1.1))
)
Expand All @@ -81,7 +81,7 @@ n_initialization_points = max(gen_int_array.size, inf_hosp_int_array.size) - 1
I0 = InfectionInitializationProcess(
"I0_initialization",
metaclass.DistributionalRV(
randomvariable.DistributionalVariable(
name="I0",
distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)),
),
Expand Down Expand Up @@ -113,11 +113,11 @@ class MyRt(metaclass.RandomVariable):
sd_rt, *_ = self.sd_rv()
# Random walk step
step_rv = metaclass.DistributionalRV(
step_rv = randomvariable.DistributionalVariable(
name="rw_step_rv", distribution=dist.Normal(0, sd_rt.value)
)
rt_init_rv = metaclass.DistributionalRV(
rt_init_rv = randomvariable.DistributionalVariable(
name="init_log_rt", distribution=dist.Normal(0, 0.2)
)
Expand All @@ -128,7 +128,7 @@ class MyRt(metaclass.RandomVariable):
)
# Transforming the random walk to the Rt scale
rt_rv = metaclass.TransformedRandomVariable(
rt_rv = randomvariable.TransformedVariable(
name="Rt_rv",
base_rv=base_rv,
transforms=transformation.ExpTransform(),
Expand All @@ -139,7 +139,7 @@ class MyRt(metaclass.RandomVariable):
rtproc = MyRt(
metaclass.DistributionalRV(
randomvariable.DistributionalVariable(
name="Rt_random_walk_sd", distribution=dist.HalfNormal(0.025)
)
)
Expand All @@ -152,9 +152,9 @@ rtproc = MyRt(
# | code-fold: true
# we place a log-Normal prior on the concentration
# parameter of the negative binomial.
nb_conc_rv = metaclass.TransformedRandomVariable(
nb_conc_rv = randomvariable.TransformedVariable(
"concentration",
metaclass.DistributionalRV(
randomvariable.DistributionalVariable(
name="concentration_raw",
distribution=dist.TruncatedNormal(loc=0, scale=1, low=0.01),
),
Expand Down Expand Up @@ -212,16 +212,16 @@ out = hosp_model.plot_posterior(

We will re-use the infection to admission interval and infection to hospitalization rate from the previous model. But we will also add a day-of-the-week effect. To do this, we will add two additional arguments to the latent hospital admissions random variable: `day_of_the_week_rv` (a `RandomVariable`) and `obs_data_first_day_of_the_week` (an `int` mapping days of the week from 0:6, zero being Monday). The `day_of_the_week_rv`'s sample method should return a vector of length seven; those values are then broadcasted to match the length of the dataset. Moreover, since the observed data may start in a weekday other than Monday, the `obs_data_first_day_of_the_week` argument is used to offset the day-of-the-week effect.

For this example, the effect will be passed as a scaled Dirichlet distribution. It will consist of a `TransformedRandomVariable` that samples an array of length seven from numpyro's `distributions.Dirichlet` and applies a `transformation.AffineTransform` to scale it by seven. [^note-other-examples]:
For this example, the effect will be passed as a scaled Dirichlet distribution. It will consist of a `TransformedVariable` that samples an array of length seven from numpyro's `distributions.Dirichlet` and applies a `transformation.AffineTransform` to scale it by seven. [^note-other-examples]:

[^note-other-examples]: A similar weekday effect is implemented in its own module, with example code [here](periodic_effects.html).

```{python}
# | label: weekly-effect
# Instantiating the day-of-the-week effect
dayofweek_effect = metaclass.TransformedRandomVariable(
dayofweek_effect = randomvariable.TransformedVariable(
name="dayofweek_effect",
base_rv=metaclass.DistributionalRV(
base_rv=randomvariable.DistributionalVariable(
name="dayofweek_effect_raw",
distribution=dist.Dirichlet(jnp.ones(7)),
),
Expand Down
15 changes: 6 additions & 9 deletions docs/source/tutorials/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,8 @@ from pyrenew.deterministic import DeterministicPMF, DeterministicVariable
from pyrenew.latent import InfectionsWithFeedback
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.process import RandomWalk
from pyrenew.metaclass import (
RandomVariable,
DistributionalRV,
TransformedRandomVariable,
)
from pyrenew.metaclass import RandomVariable
from pyrenew.randomvariable import DistributionalVariable, TransformedVariable
from pyrenew.latent import (
InfectionInitializationProcess,
InitializeInfectionsExponentialGrowth,
Expand All @@ -53,7 +50,7 @@ feedback_strength = DeterministicVariable(name="feedback_strength", value=0.01)
I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)),
DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)),
InitializeInfectionsExponentialGrowth(
gen_int_array.size,
DeterministicVariable(name="rate", value=0.05),
Expand All @@ -75,17 +72,17 @@ class MyRt(RandomVariable):
def sample(self, n: int, **kwargs) -> tuple:
sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025))
rt_rv = TransformedRandomVariable(
rt_rv = TransformedVariable(
name="log_rt_random_walk",
base_rv=RandomWalk(
name="log_rt",
step_rv=DistributionalRV(
step_rv=DistributionalVariable(
name="rw_step_rv", distribution=dist.Normal(0, 0.025)
),
),
transforms=t.ExpTransform(),
)
rt_init_rv = DistributionalRV(
rt_init_rv = DistributionalVariable(
name="init_log_rt", distribution=dist.Normal(0, 0.2)
)
init_rt, *_ = rt_init_rv.sample()
Expand Down
18 changes: 9 additions & 9 deletions docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,15 @@ With these two in hand, we can start building the model. First, we will define t

```{python}
# | label: latent-hosp
from pyrenew import latent, deterministic, metaclass
from pyrenew import latent, deterministic, metaclass, randomvariable
import jax.numpy as jnp
import numpyro.distributions as dist
inf_hosp_int = deterministic.DeterministicPMF(
name="inf_hosp_int", value=inf_hosp_int_array
)
hosp_rate = metaclass.DistributionalRV(
hosp_rate = randomvariable.DistributionalVariable(
name="IHR", distribution=dist.LogNormal(jnp.log(0.05), jnp.log(1.1))
)
Expand All @@ -155,7 +155,7 @@ latent_hosp = latent.HospitalAdmissions(
)
```

The `inf_hosp_int` is a `DeterministicPMF` object that takes the infection to hospital admission interval as input. The `hosp_rate` is a `DistributionalRV` object that takes a numpyro distribution to represent the infection to hospital admission rate. The `HospitalAdmissions` class is a `RandomVariable` that takes two distributions as inputs: the infection to admission interval and the infection to hospital admission rate. Now, we can define the rest of the other components:
The `inf_hosp_int` is a `DeterministicPMF` object that takes the infection to hospital admission interval as input. The `hosp_rate` is a `DistributionalVariable` object that takes a numpyro distribution to represent the infection to hospital admission rate. The `HospitalAdmissions` class is a `RandomVariable` that takes two distributions as inputs: the infection to admission interval and the infection to hospital admission rate. Now, we can define the rest of the other components:

```{python}
# | label: initializing-rest-of-model
Expand All @@ -171,7 +171,7 @@ latent_inf = latent.Infections()
n_initialization_points = max(gen_int_array.size, inf_hosp_int_array.size) - 1
I0 = InfectionInitializationProcess(
"I0_initialization",
metaclass.DistributionalRV(
randomvariable.DistributionalVariable(
name="I0",
distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)),
),
Expand All @@ -194,17 +194,17 @@ class MyRt(metaclass.RandomVariable):
def sample(self, n: int, **kwargs) -> tuple:
sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025))
rt_rv = metaclass.TransformedRandomVariable(
rt_rv = randomvariable.TransformedVariable(
name="log_rt_random_walk",
base_rv=process.RandomWalk(
name="log_rt",
step_rv=metaclass.DistributionalRV(
step_rv=randomvariable.DistributionalVariable(
name="rw_step_rv", distribution=dist.Normal(0, 0.025)
),
),
transforms=transformation.ExpTransform(),
)
rt_init_rv = metaclass.DistributionalRV(
rt_init_rv = randomvariable.DistributionalVariable(
name="init_log_rt", distribution=dist.Normal(0, 0.2)
)
init_rt, *_ = rt_init_rv.sample()
Expand All @@ -218,9 +218,9 @@ rtproc = MyRt()
# we place a log-Normal prior on the concentration
# parameter of the negative binomial.
nb_conc_rv = metaclass.TransformedRandomVariable(
nb_conc_rv = randomvariable.TransformedVariable(
"concentration",
metaclass.DistributionalRV(
randomvariable.DistributionalVariable(
name="concentration_raw",
distribution=dist.TruncatedNormal(loc=0, scale=1, low=0.01),
),
Expand Down
4 changes: 2 additions & 2 deletions docs/source/tutorials/periodic_effects.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ The `PeriodicBroadcaster` class can also be used to repeat a sequence as a whole

```{python}
import numpyro.distributions as dist
from pyrenew import transformation, metaclass
from pyrenew import transformation, randomvariable
# Building the transformed prior: Dirichlet * 7
mysimplex = dist.TransformedDistribution(
Expand All @@ -76,7 +76,7 @@ mysimplex = dist.TransformedDistribution(
# Constructing the day of week effect
dayofweek = process.DayOfWeekEffect(
offset=0,
quantity_to_broadcast=metaclass.DistributionalRV(
quantity_to_broadcast=randomvariable.DistributionalVariable(
name="simp", distribution=mysimplex
),
t_start=0,
Expand Down
Loading

0 comments on commit 79706e9

Please sign in to comment.