From 9298493f6b06f8eb39494b30a67c4445d089d1ed Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Wed, 17 Jul 2024 15:41:27 -0400 Subject: [PATCH 01/33] Refactor simplerandomwalk.py and associated tests, simplify process.rst --- docs/source/msei_reference/process.rst | 31 +---------- model/src/pyrenew/process/simplerandomwalk.py | 51 +++++++++++-------- model/src/test/test_random_walk.py | 48 +++++++++++++---- 3 files changed, 69 insertions(+), 61 deletions(-) diff --git a/docs/source/msei_reference/process.rst b/docs/source/msei_reference/process.rst index b004fe46..33a520c4 100644 --- a/docs/source/msei_reference/process.rst +++ b/docs/source/msei_reference/process.rst @@ -1,36 +1,7 @@ Random Process ============== -AR Processes ------------- - -.. automodule:: pyrenew.process.ar +.. automodule:: pyrenew.process :members: :undoc-members: :show-inheritance: - -First Difference (AR) ---------------------- - -.. automodule:: pyrenew.process.firstdifferencear - :members: - :undoc-members: - :show-inheritance: - -Reproduction Number Random Walk -------------------------------- - -.. automodule:: pyrenew.process.rtrandomwalk - :members: - :undoc-members: - :show-inheritance: - -Simple Random Walk ------------------- - -.. automodule:: pyrenew.process.simplerandomwalk - :members: - :undoc-members: - :show-inheritance: - -.. todo:: Determine order and naming of these modules. diff --git a/model/src/pyrenew/process/simplerandomwalk.py b/model/src/pyrenew/process/simplerandomwalk.py index c18bec67..87d36dab 100644 --- a/model/src/pyrenew/process/simplerandomwalk.py +++ b/model/src/pyrenew/process/simplerandomwalk.py @@ -2,8 +2,6 @@ # numpydoc ignore=GL08 import jax.numpy as jnp -import numpyro as npro -import numpyro.distributions as dist from numpyro.contrib.control_flow import scan from pyrenew.metaclass import RandomVariable @@ -12,45 +10,58 @@ class SimpleRandomWalkProcess(RandomVariable): """ Class for a Markovian random walk with an a - arbitrary step distribution + step distribution """ def __init__( self, - error_distribution: dist.Distribution, + name: str, + step_rv: RandomVariable, + init_rv: RandomVariable, + t_start: int = None, + t_unit: int = None, ) -> None: """ Default constructor Parameters ---------- - error_distribution : dist.Distribution - Passed to numpyro.sample. + name : str + A name for the random variable, used to + name sites within it in :fun :`numpyro.sample()` + calls. + step_rv : RandomVariable + RandomVariable representing the step distribution. + init_rv : RandomVariable + RandomVariable representing the initial value of + the process + t_start : int + See :class:`RandomVariable` + t_unit : int + See :class:`RandomVariable` Returns ------- None """ - self.error_distribution = error_distribution + self.name = name + self.step_rv = step_rv + self.init_rv = init_rv + self.t_start = t_start + self.t_unit = t_unit def sample( self, n_timepoints: int, - name: str = "randomwalk", - init: float = None, **kwargs, ) -> tuple: """ - Samples from the randomwalk + Sample from the random walk. Parameters ---------- n_timepoints : int - Length of the walk. - name : str, optional - Passed to numpyro.sample, by default "randomwalk" - init : float, optional - Initial point of the walk, by default None + Length of the walk to sample. **kwargs : dict, optional Additional keyword arguments passed through to internal sample() calls, should there be any. @@ -61,12 +72,11 @@ def sample( With a single array of shape (n_timepoints,). """ - if init is None: - init = npro.sample(name + "_init", self.error_distribution) + init, *_ = self.init_rv.sample(**kwargs) def transition(x_prev, _): # numpydoc ignore=GL08 - diff = npro.sample(name + "_diffs", self.error_distribution) + diff, *_ = self.step_rv.sample(**kwargs) x_curr = x_prev + diff return x_curr, x_curr @@ -76,11 +86,12 @@ def transition(x_prev, _): xs=jnp.arange(n_timepoints - 1), ) - return (jnp.hstack([init, x]),) + return (jnp.hstack([init, x.flatten()]),) @staticmethod def validate(): """ - Validates inputted parameters, implementation pending. + Validates input parameters, implementation pending. """ + super().validate() return None diff --git a/model/src/test/test_random_walk.py b/model/src/test/test_random_walk.py index c2dcb186..d6d6ad09 100755 --- a/model/src/test/test_random_walk.py +++ b/model/src/test/test_random_walk.py @@ -4,6 +4,8 @@ import numpyro import numpyro.distributions as dist from numpy.testing import assert_almost_equal +from pyrenew.deterministic import DeterministicVariable +from pyrenew.metaclass import DistributionalRV from pyrenew.process import SimpleRandomWalkProcess @@ -12,16 +14,32 @@ def test_rw_can_be_sampled(): Check that a simple random walk can be initialized and sampled from """ - rw_normal = SimpleRandomWalkProcess(dist.Normal(0, 1)) + init_rv_rand = DistributionalRV(dist.Normal(1, 0.5), "init_rv_rand") + init_rv_fixed = DeterministicVariable(50.0, "init_rv_fixed") + + step_rv = DistributionalRV(dist.Normal(0, 1), "rw_step") + + rw_init_rand = SimpleRandomWalkProcess( + "rw_rand_init", step_rv=step_rv, init_rv=init_rv_rand + ) + + rw_init_fixed = SimpleRandomWalkProcess( + "rw_fixed_init", step_rv=step_rv, init_rv=init_rv_fixed + ) with numpyro.handlers.seed(rng_seed=62): - # can sample with and without inits - ans0 = rw_normal(n_timepoints=3532, init=50.0) - ans1 = rw_normal(n_timepoints=5023) + # can sample with a fixed init + # and with a random init + ans_rand = rw_init_rand(n_timepoints=3532) + ans_fixed = rw_init_fixed(n_timepoints=5023) - # check that the samples are of the right shape - assert ans0[0].shape == (3532,) - assert ans1[0].shape == (5023,) + # check that the samples are of the right shape + assert ans_rand[0].shape == (3532,) + assert ans_fixed[0].shape == (5023,) + + # check that fixing inits works + assert_almost_equal(ans_fixed[0][0], init_rv_fixed.vars) + assert ans_rand[0][0] != init_rv_fixed.vars def test_rw_samples_correctly_distributed(): @@ -34,10 +52,18 @@ def test_rw_samples_correctly_distributed(): for step_mean, step_sd in zip( [0, 2.253, -3.2521, 1052, 1e-6], [1, 0.025, 3, 1, 0.02] ): - rw_normal = SimpleRandomWalkProcess(dist.Normal(step_mean, step_sd)) - rw_init = 532.0 + rw_init_val = 532.0 + rw_normal = SimpleRandomWalkProcess( + name="rw_normal_test", + step_rv=DistributionalRV( + dist=dist.Normal(loc=step_mean, scale=step_sd), + name="rw_normal_dist", + ), + init_rv=DeterministicVariable(rw_init_val, "init_rv_fixed"), + ) + with numpyro.handlers.seed(rng_seed=62): - samples, *_ = rw_normal(n_timepoints=n_samples, init=rw_init) + samples, *_ = rw_normal(n_timepoints=n_samples) # Checking the shape assert samples.shape == (n_samples,) @@ -60,4 +86,4 @@ def test_rw_samples_correctly_distributed(): assert jnp.abs(jnp.log(jnp.std(diffs) / step_sd)) < jnp.log(1.1) # first value should be the init value - assert_almost_equal(samples[0], rw_init) + assert_almost_equal(samples[0], rw_init_val) From f6ceb6553871b9c0db586c67794a463130dd3699 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Wed, 17 Jul 2024 17:25:28 -0400 Subject: [PATCH 02/33] Add TransformedRandomVariable metaclass and tests --- model/src/pyrenew/metaclass.py | 117 +++++++++++++++++ model/src/test/test_transformed_rv_class.py | 138 ++++++++++++++++++++ 2 files changed, 255 insertions(+) create mode 100644 model/src/test/test_transformed_rv_class.py diff --git a/model/src/pyrenew/metaclass.py b/model/src/pyrenew/metaclass.py index 57d95885..e5990b65 100644 --- a/model/src/pyrenew/metaclass.py +++ b/model/src/pyrenew/metaclass.py @@ -17,6 +17,7 @@ from jax.typing import ArrayLike from numpyro.infer import MCMC, NUTS, Predictive from pyrenew.mcmcutils import plot_posterior, spread_draws +from pyrenew.transformation import Transform def _assert_sample_and_rtype( @@ -581,3 +582,119 @@ def prior_predictive( ) return predictive(rng_key, **kwargs) + + +class TransformedRandomVariable(RandomVariable): + """ + Class to represent RandomVariables defined + by taking the output of another RV's + :meth:`RandomVariable.sample()` method + and transforming it by a given transformation + (typically a :class:`Transform`) + """ + + def __init__( + self, + name: str, + base_rv: RandomVariable, + transforms: Transform | tuple[Transform], + ): + """ + Default constructor + + Parameters + ---------- + + name : str + A name for the random variable instance + + base_rv : RandomVariable + The underlying (untransformed) RandomVariable + + transforms : Transform + Transformation or tuple of transformations + to apply to the output of + `base_rv.sample()`; single values will be coerced to + a length-one tuple. If a tuple, should be the same + length as the tuple returned by `base_rv.sample()` + + Returns + ------- + None + """ + self.name = name + self.base_rv = base_rv + + if not isinstance(transforms, tuple): + transforms = (transforms,) + self.transforms = transforms + self.validate() + + def sample(self, **kwargs): + """ + Sample method. Call self.base_rv.sample() + and then apply the transforms specified + in self.transforms. + + Parameters + ---------- + **kwargs : + Keyword arguments passed to self.base_rv.sample() + + Returns + ------- + tuple of the same length as the tuple returned by + self.base_rv.sample() + """ + + untransformed_values = self.base_rv.sample(**kwargs) + + return tuple( + t(uv) for t, uv in zip(self.transforms, untransformed_values) + ) + + def sample_length(self): + """ + Sample length for a transformed + random variable must be equal to the + length of self.transforms or + validation will fail. + + Returns + ------- + int + Equal to the length self.transforms + """ + return len(self.transforms) + + def validate(self): + """ + Perform validation checks on a + TransformedRandomVariable instance, + confirming that all transformations + are callable and that the number of + transformations is equal to the sample + length of the base random variable. + + Returns + ------- + None + on successful validation, or raise a ValueError + """ + for t in self.transforms: + if not callable(t): + raise ValueError( + "All entries in self.transforms " "must be callable" + ) + if hasattr(self.base_rv, "sample_length"): + n_transforms = len(self.transforms) + n_entries = self.base_rv.sample_length() + if not n_transforms == n_entries: + raise ValueError( + "There must be exactly as many transformations " + "specified as entries self.transforms as there are " + "entries in the tuple returned by " + "self.base_rv.sample()." + f"Got {n_transforms} transforms and {n_entries} " + "entries" + ) diff --git a/model/src/test/test_transformed_rv_class.py b/model/src/test/test_transformed_rv_class.py new file mode 100644 index 00000000..cf52b487 --- /dev/null +++ b/model/src/test/test_transformed_rv_class.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- + +""" +Tests for TransformedRandomVariable class +""" + +import numpyro +import numpyro.distributions as dist +import pyrenew.transformation as t +import pytest +from numpy.testing import assert_almost_equal +from pyrenew.metaclass import ( + DistributionalRV, + RandomVariable, + TransformedRandomVariable, +) + + +class LengthTwoRV(RandomVariable): + """ + Class for a RandomVariable + with sample_length 2 + and values 1 and 5 + """ + + def sample(self, **kwargs): + """ + Deterministic sampling method + that returns a length-2 tuple + + Returns + ------- + tuple + (1, 5) + """ + return (1, 5) + + def sample_length(self): + """ + Report the sample length as 2 + + Returns + ------- + int + 2 + """ + return 2 + + def validate(self): + """ + No validation. + + Returns + ------- + None + """ + return None + + +def test_transform_rv_validation(): + """ + Test that a TransformedRandomVariable validation + works as expected. + """ + + base_rv = DistributionalRV(dist.Normal(0, 1), "test_normal") + base_rv.sample_length = lambda: 1 # numpydoc ignore=GL08 + + l2_rv = LengthTwoRV() + + test_transforms = [t.IdentityTransform(), t.ExpTransform()] + + for tr in test_transforms: + my_rv = TransformedRandomVariable("test_transformed_rv", base_rv, tr) + assert isinstance(my_rv.transforms, tuple) + assert len(my_rv.transforms) == 1 + assert my_rv.sample_length() == 1 + not_callable_err = "All entries in self.transforms " "must be callable" + sample_length_err = "There must be exactly as many transformations" + with pytest.raises(ValueError, match=sample_length_err): + _ = TransformedRandomVariable( + "should_error_due_to_too_many_transforms", base_rv, (tr, tr) + ) + with pytest.raises(ValueError, match=sample_length_err): + _ = TransformedRandomVariable( + "should_error_due_to_too_few_transforms", l2_rv, tr + ) + with pytest.raises(ValueError, match=sample_length_err): + _ = TransformedRandomVariable( + "should_also_error_due_to_too_few_transforms", l2_rv, (tr,) + ) + with pytest.raises(ValueError, match=not_callable_err): + _ = TransformedRandomVariable( + "should_error_due_to_not_callable", l2_rv, (1,) + ) + with pytest.raises(ValueError, match=not_callable_err): + _ = TransformedRandomVariable( + "should_error_due_to_not_callable", base_rv, (1,) + ) + + +def test_transforms_applied_at_sampling(): + """ + Test that TransformedRandomVariable + instances correctly apply their specified + transformations at sampling + """ + norm_rv = DistributionalRV(dist.Normal(0, 1), "test_normal") + norm_rv.sample_length = lambda: 1 + + l2_rv = LengthTwoRV() + + for tr in [ + t.IdentityTransform(), + t.ExpTransform(), + t.ExpTransform().inv, + t.ScaledLogitTransform(5), + ]: + tr_norm = TransformedRandomVariable("transformed_normal", norm_rv, tr) + + tr_l2 = TransformedRandomVariable( + "transformed_length_2", l2_rv, (tr, t.ExpTransform()) + ) + + with numpyro.handlers.seed(rng_seed=5): + norm_base_sample = norm_rv.sample() + l2_base_sample = l2_rv.sample() + with numpyro.handlers.seed(rng_seed=5): + norm_transformed_sample = tr_norm.sample() + l2_transformed_sample = tr_l2.sample() + + assert_almost_equal( + (tr(norm_base_sample[0]),), norm_transformed_sample + ) + assert_almost_equal( + (tr(l2_base_sample[0]), t.ExpTransform()(l2_base_sample[1])), + l2_transformed_sample, + ) From da8ba1b82c939046b72ba415c22141b7f00e0522 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Wed, 17 Jul 2024 22:27:36 -0400 Subject: [PATCH 03/33] Clean up metaclass docs, remove DistributionalRVSample class --- model/src/pyrenew/metaclass.py | 74 +++++++++++++++------------------- 1 file changed, 33 insertions(+), 41 deletions(-) diff --git a/model/src/pyrenew/metaclass.py b/model/src/pyrenew/metaclass.py index e5990b65..0651d341 100644 --- a/model/src/pyrenew/metaclass.py +++ b/model/src/pyrenew/metaclass.py @@ -5,14 +5,15 @@ """ from abc import ABCMeta, abstractmethod -from typing import NamedTuple, get_type_hints +from typing import get_type_hints import jax import jax.numpy as jnp import jax.random as jr import matplotlib.pyplot as plt import numpy as np -import numpyro as npro +import numpyro +import numpyro.distributions as dist import polars as pl from jax.typing import ArrayLike from numpyro.infer import MCMC, NUTS, Predictive @@ -210,33 +211,15 @@ def __call__(self, **kwargs): return self.sample(**kwargs) -class DistributionalRVSample(NamedTuple): - """ - Named tuple for the sample method of DistributionalRV - - Attributes - ---------- - value : ArrayLike - Sampled value from the distribution. - """ - - value: ArrayLike | None = None - - def __repr__(self) -> str: - """ - Representation of the DistributionalRVSample - """ - return f"DistributionalRVSample(value={self.value})" - - class DistributionalRV(RandomVariable): """ - Wrapper class for random variables that sample from a single `numpyro.distributions.Distribution`. + Wrapper class for random variables that sample + from a single :class:`numpyro.distributions.Distribution`. """ def __init__( self, - dist: npro.distributions.Distribution, + dist: dist.Distribution, name: str, ): """ @@ -244,7 +227,7 @@ def __init__( Parameters ---------- - dist : npro.distributions.Distribution + dist : dist.Distribution Distribution of the random variable. name : str Name of the random variable. @@ -266,7 +249,7 @@ def validate(dist: any) -> None: """ Validation of the distribution to be implemented in subclasses. """ - if not isinstance(dist, npro.distributions.Distribution): + if not isinstance(dist, dist.Distribution): raise ValueError( "dist should be an instance of " f"numpyro.distributions.Distribution, got {dist}" @@ -278,25 +261,27 @@ def sample( self, obs: ArrayLike | None = None, **kwargs, - ) -> DistributionalRVSample: + ) -> tuple: """ Sample from the distribution. Parameters ---------- obs : ArrayLike, optional - Observations passed as the `obs` argument to `numpyro.sample()`. Default `None`. + Observations passed as the `obs` argument to + :fun:`numpyro.sample()`. Default `None`. **kwargs : dict, optional - Additional keyword arguments passed through to internal sample calls, - should there be any. + Additional keyword arguments passed through + to internal sample calls, should there be any. Returns ------- - DistributionalRVSample + tuple + Containing the sampled from the distribution. """ - return DistributionalRVSample( - value=jnp.atleast_1d( - npro.sample( + return ( + jnp.atleast_1d( + numpyro.sample( name=self.name, fn=self.dist, obs=obs, @@ -417,9 +402,13 @@ def run( Parameters ---------- nuts_args : dict, optional - Dictionary of arguments passed to the NUTS. Defaults to None. + Dictionary of arguments passed to the + :class:`numpyro.infer.NUTS` kernel. + Defaults to None. mcmc_args : dict, optional - Dictionary of passed to the MCMC sampler. Defaults to None. + Dictionary of arguments passed to the + :class:`numpyro.infer.MCMC` constructor. + Defaults to None. Returns ------- @@ -449,14 +438,14 @@ def print_summary( exclude_deterministic: bool = True, ) -> None: """ - A wrapper of MCMC.print_summary + A wrapper of :meth:`numpyro.infer.MCMC.print_summary` Parameters ---------- prob : float, optional - The acceptance probability of print_summary. Defaults to 0.9 + The width of the credible interval to show. Default 0.9 exclude_deterministic : bool, optional - Whether to print deterministic variables in the summary. + Whether to print deterministic sites in the summary. Defaults to True. Returns @@ -512,16 +501,19 @@ def posterior_predictive( **kwargs, ) -> dict: """ - A wrapper for numpyro.infer.Predictive to generate posterior predictive samples. + A wrapper for :class:`numpyro.infer.Predictive` to generate + posterior predictive samples. Parameters ---------- rng_key : ArrayLike, optional Random key for the Predictive function call. Defaults to None. numpyro_predictive_args : dict, optional - Dictionary of arguments to be passed to the numpyro.inference.Predictive constructor. + Dictionary of arguments to be passed to the + :class:`numpyro.inference.Predictive` constructor. **kwargs - Additional named arguments passed to the `__call__()` method of numpyro.inference.Predictive + Additional named arguments passed to the + `__call__()` method of :class:`numpyro.infer.Predictive` Returns ------- From c0d1cdcbe402beb4742dc1262e451268ed5de73c Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Wed, 17 Jul 2024 22:30:58 -0400 Subject: [PATCH 04/33] fix dist clash --- model/src/pyrenew/metaclass.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/model/src/pyrenew/metaclass.py b/model/src/pyrenew/metaclass.py index 0651d341..16c7b8f7 100644 --- a/model/src/pyrenew/metaclass.py +++ b/model/src/pyrenew/metaclass.py @@ -13,7 +13,6 @@ import matplotlib.pyplot as plt import numpy as np import numpyro -import numpyro.distributions as dist import polars as pl from jax.typing import ArrayLike from numpyro.infer import MCMC, NUTS, Predictive @@ -219,9 +218,9 @@ class DistributionalRV(RandomVariable): def __init__( self, - dist: dist.Distribution, + dist: numpyro.distributions.Distribution, name: str, - ): + ) -> None: """ Default constructor for DistributionalRV. @@ -249,7 +248,7 @@ def validate(dist: any) -> None: """ Validation of the distribution to be implemented in subclasses. """ - if not isinstance(dist, dist.Distribution): + if not isinstance(dist, numpyro.distributions.Distribution): raise ValueError( "dist should be an instance of " f"numpyro.distributions.Distribution, got {dist}" From 91230c7e3b94da9fccf4aaf8673e44102065e17d Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Wed, 17 Jul 2024 23:10:09 -0400 Subject: [PATCH 05/33] Rewrite all tests to pass without RtRandomWalkProcess --- model/src/pyrenew/metaclass.py | 2 +- .../pyrenew/model/rtinfectionsrenewalmodel.py | 4 +- model/src/pyrenew/process/__init__.py | 2 - model/src/pyrenew/process/rtrandomwalk.py | 125 ------------------ model/src/test/test_forecast.py | 17 ++- model/src/test/test_latent_admissions.py | 17 ++- model/src/test/test_latent_infections.py | 16 ++- model/src/test/test_model_basic_renewal.py | 60 +++++---- model/src/test/test_model_hospitalizations.py | 93 +++++++------ model/src/test/test_predictive.py | 19 ++- model/src/test/test_random_key.py | 20 +-- 11 files changed, 144 insertions(+), 231 deletions(-) delete mode 100644 model/src/pyrenew/process/rtrandomwalk.py diff --git a/model/src/pyrenew/metaclass.py b/model/src/pyrenew/metaclass.py index 16c7b8f7..03b417ea 100644 --- a/model/src/pyrenew/metaclass.py +++ b/model/src/pyrenew/metaclass.py @@ -621,7 +621,7 @@ def __init__( self.transforms = transforms self.validate() - def sample(self, **kwargs): + def sample(self, **kwargs) -> tuple: """ Sample method. Call self.base_rv.sample() and then apply the transforms specified diff --git a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py index 8a80fbc1..f6f9846f 100644 --- a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py +++ b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py @@ -164,7 +164,8 @@ def sample( Notes ----- - Either `data_observed_infections` or `n_timepoints_to_simulate` must be specified, not both. + Either `data_observed_infections` or `n_timepoints_to_simulate` + must be specified, not both. Returns ------- @@ -238,6 +239,7 @@ def sample( jnp.nan, pad_direction="start", ) + npro.deterministic("Rt", Rt) return RtInfectionsRenewalSample( Rt=Rt, diff --git a/model/src/pyrenew/process/__init__.py b/model/src/pyrenew/process/__init__.py index 1613f168..bad08343 100644 --- a/model/src/pyrenew/process/__init__.py +++ b/model/src/pyrenew/process/__init__.py @@ -9,13 +9,11 @@ RtPeriodicDiffProcess, RtWeeklyDiffProcess, ) -from pyrenew.process.rtrandomwalk import RtRandomWalkProcess from pyrenew.process.simplerandomwalk import SimpleRandomWalkProcess __all__ = [ "ARProcess", "FirstDifferenceARProcess", - "RtRandomWalkProcess", "SimpleRandomWalkProcess", "RtPeriodicDiffProcess", "RtWeeklyDiffProcess", diff --git a/model/src/pyrenew/process/rtrandomwalk.py b/model/src/pyrenew/process/rtrandomwalk.py deleted file mode 100644 index 5722fd54..00000000 --- a/model/src/pyrenew/process/rtrandomwalk.py +++ /dev/null @@ -1,125 +0,0 @@ -# -*- coding: utf-8 -*- -# numpydoc ignore=GL08 - -import numpyro as npro -import numpyro.distributions as dist -import pyrenew.transformation as t -from pyrenew.metaclass import RandomVariable -from pyrenew.process.simplerandomwalk import SimpleRandomWalkProcess - - -class RtRandomWalkProcess(RandomVariable): - r"""Rt Randomwalk Process - - Notes - ----- - - The process is defined as follows: - - .. math:: - - Rt(0) &\sim \text{Rt0_dist} \\ - Rt(t) &\sim \text{Rt_transform}(\text{Rt_transformed_rw}(t)) - """ - - def __init__( - self, - Rt0_dist: dist.Distribution, - Rt_rw_dist: dist.Distribution, - Rt_transform: t.Transform | None = None, - ) -> None: - """ - Default constructor - - Parameters - ---------- - Rt0_dist : dist.Distribution - Initial distribution of Rt. - Rt_rw_dist : dist.Distribution - Randomwalk process. - Rt_transform : numpyro.distributions.transformers.Transform, optional - Transformation applied to the sampled Rt0. If None, the identity - transformation is used. - - Returns - ------- - None - """ - if Rt_transform is None: - Rt_transform = t.IdentityTransform() - - RtRandomWalkProcess.validate(Rt0_dist, Rt_transform, Rt_rw_dist) - - self.Rt0_dist = Rt0_dist - self.Rt_transform = Rt_transform - self.Rt_rw_dist = Rt_rw_dist - - return None - - @staticmethod - def validate( - Rt0_dist: dist.Distribution, - Rt_transform: t.Transform, - Rt_rw_dist: dist.Distribution, - ) -> None: - """ - Validates Rt0_dist, Rt_transform, and Rt_rw_dist. - - Parameters - ---------- - Rt0_dist : dist.Distribution, optional - Initial distribution of Rt, expected dist.Distribution - Rt_transform : numpyro.distributions.transforms.Transform - Transformation applied to the sampled Rt0. - Rt_rw_dist : any - Randomwalk process, expected dist.Distribution. - - Returns - ------- - None - - Raises - ------ - AssertionError - If Rt0_dist or Rt_rw_dist are not dist.Distribution or if - Rt_transform is not numpyro.distributions.transforms.Transform. - """ - assert isinstance(Rt0_dist, dist.Distribution) - assert isinstance(Rt_transform, t.Transform) - assert isinstance(Rt_rw_dist, dist.Distribution) - - def sample( - self, - n_timepoints: int, - **kwargs, - ) -> tuple: - """ - Generate samples from the process - - Parameters - ---------- - n_timepoints : int - Number of timepoints to sample. - **kwargs : dict, optional - Additional keyword arguments passed through to internal sample() - calls, should there be any. - - Returns - ------- - tuple - With a single array of shape (n_timepoints,). - """ - - Rt0 = npro.sample("Rt0", self.Rt0_dist) - - Rt0_trans = self.Rt_transform(Rt0) - Rt_trans_proc = SimpleRandomWalkProcess(self.Rt_rw_dist) - Rt_trans_ts, *_ = Rt_trans_proc( - n_timepoints=n_timepoints, - name="Rt_transformed_rw", - init=Rt0_trans, - ) - - Rt = npro.deterministic("Rt", self.Rt_transform.inv(Rt_trans_ts)) - - return (Rt,) diff --git a/model/src/test/test_forecast.py b/model/src/test/test_forecast.py index 523297b3..02255544 100644 --- a/model/src/test/test_forecast.py +++ b/model/src/test/test_forecast.py @@ -13,10 +13,10 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV +from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation -from pyrenew.process import RtRandomWalkProcess +from pyrenew.process import SimpleRandomWalkProcess def test_forecast(): @@ -31,11 +31,16 @@ def test_forecast(): ) latent_infections = Infections() observed_infections = PoissonObservation("poisson_rv") - rt = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), + rt = TransformedRandomVariable( + "Rt_rv", + base_rv=SimpleRandomWalkProcess( + name="log_rt", + step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), + init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + ), + transforms=t.ExpTransform(), ) + model = RtInfectionsRenewalModel( I0_rv=I0, gen_int_rv=gen_int, diff --git a/model/src/test/test_latent_admissions.py b/model/src/test/test_latent_admissions.py index 6d5abea1..a9e2bb5c 100644 --- a/model/src/test/test_latent_admissions.py +++ b/model/src/test/test_latent_admissions.py @@ -9,8 +9,8 @@ from pyrenew import transformation as t from pyrenew.deterministic import DeterministicPMF from pyrenew.latent import HospitalAdmissions, Infections -from pyrenew.metaclass import DistributionalRV -from pyrenew.process import RtRandomWalkProcess +from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable +from pyrenew.process import SimpleRandomWalkProcess def test_admissions_sample(): @@ -22,11 +22,16 @@ def test_admissions_sample(): # Generating Rt and Infections to compute the hospital admissions np.random.seed(223) - rt = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), + rt = TransformedRandomVariable( + "Rt_rv", + base_rv=SimpleRandomWalkProcess( + name="log_rt", + step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), + init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + ), + transforms=t.ExpTransform(), ) + with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): sim_rt, *_ = rt(n_timepoints=30) diff --git a/model/src/test/test_latent_infections.py b/model/src/test/test_latent_infections.py index 0c9be7d4..04638d59 100755 --- a/model/src/test/test_latent_infections.py +++ b/model/src/test/test_latent_infections.py @@ -9,7 +9,8 @@ import pyrenew.transformation as t import pytest from pyrenew.latent import Infections -from pyrenew.process import RtRandomWalkProcess +from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable +from pyrenew.process import SimpleRandomWalkProcess def test_infections_as_deterministic(): @@ -19,11 +20,16 @@ def test_infections_as_deterministic(): """ np.random.seed(223) - rt = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), + rt = TransformedRandomVariable( + "Rt_rv", + base_rv=SimpleRandomWalkProcess( + name="log_rt", + step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), + init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + ), + transforms=t.ExpTransform(), ) + with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): sim_rt, *_ = rt(n_timepoints=30) diff --git a/model/src/test/test_model_basic_renewal.py b/model/src/test/test_model_basic_renewal.py index d0961eb3..ba66cda3 100644 --- a/model/src/test/test_model_basic_renewal.py +++ b/model/src/test/test_model_basic_renewal.py @@ -16,15 +16,39 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV +from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation -from pyrenew.process import RtRandomWalkProcess +from pyrenew.process import SimpleRandomWalkProcess + + +def get_default_rt(): + """ + Helper function to create a default Rt + RandomVariable for this testing session. + + Returns + ------- + TransformedRandomVariable : + A log-scale random walk with fixed + init value and step size priors + """ + return TransformedRandomVariable( + "Rt_rv", + base_rv=SimpleRandomWalkProcess( + name="log_rt", + step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), + init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + ), + transforms=t.ExpTransform(), + ) def test_model_basicrenewal_no_timepoints_or_observations(): """ - Test that the basic renewal model does not run without either n_timepoints_to_simulate or observed_admissions + Test that the basic renewal model does not run + without either n_timepoints_to_simulate or + observed_admissions """ gen_int = DeterministicPMF( @@ -37,11 +61,7 @@ def test_model_basicrenewal_no_timepoints_or_observations(): observed_infections = PoissonObservation("poisson_rv") - rt = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), - ) + rt = get_default_rt() model1 = RtInfectionsRenewalModel( I0_rv=I0, @@ -74,11 +94,7 @@ def test_model_basicrenewal_both_timepoints_and_observations(): observed_infections = PoissonObservation("possion_rv") - rt = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), - ) + rt = get_default_rt() model1 = RtInfectionsRenewalModel( I0_rv=I0, @@ -119,11 +135,7 @@ def test_model_basicrenewal_no_obs_model(): latent_infections = Infections() - rt = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), - ) + rt = get_default_rt() model0 = RtInfectionsRenewalModel( gen_int_rv=gen_int, @@ -197,11 +209,7 @@ def test_model_basicrenewal_with_obs_model(): observed_infections = PoissonObservation("poisson_rv") - rt = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), - ) + rt = get_default_rt() model1 = RtInfectionsRenewalModel( I0_rv=I0, @@ -251,11 +259,7 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08 observed_infections = PoissonObservation("poisson_rv") - rt = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), - ) + rt = get_default_rt() model1 = RtInfectionsRenewalModel( I0_rv=I0, diff --git a/model/src/test/test_model_hospitalizations.py b/model/src/test/test_model_hospitalizations.py index 7eb6a8c4..be352bcf 100644 --- a/model/src/test/test_model_hospitalizations.py +++ b/model/src/test/test_model_hospitalizations.py @@ -5,7 +5,7 @@ import jax.numpy as jnp import jax.random as jr import numpy as np -import numpyro as npro +import numpyro import numpyro.distributions as dist import polars as pl import pytest @@ -21,10 +21,36 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV, RandomVariable +from pyrenew.metaclass import ( + DistributionalRV, + RandomVariable, + TransformedRandomVariable, +) from pyrenew.model import HospitalAdmissionsModel from pyrenew.observation import PoissonObservation -from pyrenew.process import RtRandomWalkProcess +from pyrenew.process import SimpleRandomWalkProcess + + +def get_default_rt(): + """ + Helper function to create a default Rt + RandomVariable for this testing session. + + Returns + ------- + TransformedRandomVariable : + A log-scale random walk with fixed + init value and step size priors + """ + return TransformedRandomVariable( + "Rt_rv", + base_rv=SimpleRandomWalkProcess( + name="log_rt", + step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), + init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + ), + transforms=t.ExpTransform(), + ) class UniformProbForTest(RandomVariable): # numpydoc ignore=GL08 @@ -39,13 +65,16 @@ def validate(self): # numpydoc ignore=GL08 def sample(self, **kwargs): # numpydoc ignore=GL08 return ( - npro.sample(name=self.name, fn=dist.Uniform(high=0.99, low=0.01)), + numpyro.sample( + name=self.name, fn=dist.Uniform(high=0.99, low=0.01) + ), ) def test_model_hosp_no_timepoints_or_observations(): """ - Checks that the Hospitalization model does not run without either n_timepoints_to_simulate or observed_admissions + Checks that the Hospitalization model does not run + without either n_timepoints_to_simulate or observed_admissions """ gen_int = DeterministicPMF( @@ -55,11 +84,8 @@ def test_model_hosp_no_timepoints_or_observations(): I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") latent_infections = Infections() - Rt_process = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), - ) + Rt_process = get_default_rt() + observed_admissions = PoissonObservation("poisson_rv") inf_hosp = DeterministicPMF( @@ -105,7 +131,7 @@ def test_model_hosp_no_timepoints_or_observations(): ) np.random.seed(223) - with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): + with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)): with pytest.raises(ValueError, match="Either"): model1.sample( n_timepoints_to_simulate=None, data_observed_admissions=None @@ -124,11 +150,8 @@ def test_model_hosp_both_timepoints_and_observations(): I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") latent_infections = Infections() - Rt_process = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), - ) + Rt_process = get_default_rt() + observed_admissions = PoissonObservation("poisson_rv") inf_hosp = DeterministicPMF( @@ -174,7 +197,7 @@ def test_model_hosp_both_timepoints_and_observations(): ) np.random.seed(223) - with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): + with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)): with pytest.raises(ValueError, match="Cannot pass both"): model1.sample( n_timepoints_to_simulate=30, @@ -200,11 +223,8 @@ def test_model_hosp_no_obs_model(): ) latent_infections = Infections() - Rt_process = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), - ) + Rt_process = get_default_rt() + inf_hosp = DeterministicPMF( jnp.array( [ @@ -250,13 +270,13 @@ def test_model_hosp_no_obs_model(): # Sampling and fitting model 0 (with no obs for infections) np.random.seed(223) - with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): + with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)): model0_samp = model0.sample(n_timepoints_to_simulate=30) model0.hosp_admission_obs_process_rv = NullObservation() np.random.seed(223) - with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): + with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)): model1_samp = model0.sample(n_timepoints_to_simulate=30) np.testing.assert_array_almost_equal(model0_samp.Rt, model1_samp.Rt) @@ -310,11 +330,7 @@ def test_model_hosp_with_obs_model(): ) latent_infections = Infections() - Rt_process = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), - ) + Rt_process = get_default_rt() observed_admissions = PoissonObservation("poisson_rv") inf_hosp = DeterministicPMF( @@ -361,7 +377,7 @@ def test_model_hosp_with_obs_model(): # Sampling and fitting model 0 (with no obs for infections) np.random.seed(223) - with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): + with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)): model1_samp = model1.sample(n_timepoints_to_simulate=30) model1.run( @@ -400,11 +416,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): ) latent_infections = Infections() - Rt_process = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), - ) + Rt_process = get_default_rt() observed_admissions = PoissonObservation("poisson_rv") inf_hosp = DeterministicPMF( @@ -462,7 +474,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): # Sampling and fitting model 0 (with no obs for infections) np.random.seed(223) - with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): + with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)): model1_samp = model1.sample(n_timepoints_to_simulate=30) model1.run( @@ -503,11 +515,8 @@ def test_model_hosp_with_obs_model_weekday_phosp(): ) latent_infections = Infections() - Rt_process = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), - ) + Rt_process = get_default_rt() + observed_admissions = PoissonObservation("poisson_rv") inf_hosp = DeterministicPMF( @@ -575,7 +584,7 @@ def test_model_hosp_with_obs_model_weekday_phosp(): # Sampling and fitting model 0 (with no obs for infections) np.random.seed(223) - with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): + with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)): model1_samp = model1.sample( n_timepoints_to_simulate=n_obs_to_generate, padding=pad_size ) diff --git a/model/src/test/test_predictive.py b/model/src/test/test_predictive.py index 1089974e..d98269a1 100644 --- a/model/src/test/test_predictive.py +++ b/model/src/test/test_predictive.py @@ -14,12 +14,12 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV +from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation -from pyrenew.process import RtRandomWalkProcess +from pyrenew.process import SimpleRandomWalkProcess -pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25]) +pmf_array = jnp.array([0.25, 0.1, 0.2, 0.45]) gen_int = DeterministicPMF(pmf_array, name="gen_int") I0 = InfectionInitializationProcess( "I0_initialization", @@ -29,11 +29,16 @@ ) latent_infections = Infections() observed_infections = PoissonObservation("poisson_rv") -rt = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), +rt = TransformedRandomVariable( + "Rt_rv", + base_rv=SimpleRandomWalkProcess( + name="log_rt", + step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), + init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + ), + transforms=t.ExpTransform(), ) + model = RtInfectionsRenewalModel( I0_rv=I0, gen_int_rv=gen_int, diff --git a/model/src/test/test_random_key.py b/model/src/test/test_random_key.py index 5f1c9986..6a44e4ff 100644 --- a/model/src/test/test_random_key.py +++ b/model/src/test/test_random_key.py @@ -8,7 +8,7 @@ import jax.numpy as jnp import jax.random as jr import numpy as np -import numpyro as npro +import numpyro import numpyro.distributions as dist import pyrenew.transformation as t from numpy.testing import assert_array_equal, assert_raises @@ -18,10 +18,10 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV +from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation -from pyrenew.process import RtRandomWalkProcess +from pyrenew.process import SimpleRandomWalkProcess def create_test_model(): # numpydoc ignore=GL08 @@ -35,10 +35,14 @@ def create_test_model(): # numpydoc ignore=GL08 ) latent_infections = Infections() observed_infections = PoissonObservation("poisson_rv") - rt = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), + rt = TransformedRandomVariable( + "Rt_rv", + base_rv=SimpleRandomWalkProcess( + name="log_rt", + step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), + init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + ), + transforms=t.ExpTransform(), ) model = RtInfectionsRenewalModel( I0_rv=I0, @@ -99,7 +103,7 @@ def test_rng_keys_produce_correct_samples(): ] # sample only a single model and use that model's samples # as the observed_infections for the rest of the models - with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): + with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)): model_sample = models[0].sample( n_timepoints_to_simulate=n_timepoints_to_simulate[0] ) From fcd8a2c2469f5e1ce1455a7a718ee913e71f9e0a Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Wed, 17 Jul 2024 23:53:30 -0400 Subject: [PATCH 06/33] Adapt all tutorials --- docs/source/tutorials/basic_renewal_model.qmd | 22 +++++++------ docs/source/tutorials/extending_pyrenew.qmd | 16 ++++++---- .../tutorials/hospital_admissions_model.qmd | 16 +++++++--- docs/source/tutorials/pyrenew_demo.qmd | 31 ++++++++++++------- model/pyproject.toml | 1 + 5 files changed, 56 insertions(+), 30 deletions(-) diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index 2e324a08..4efec4b5 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -16,7 +16,7 @@ import jax.numpy as jnp import numpy as np import numpyro as npro import numpyro.distributions as dist -from pyrenew.process import RtRandomWalkProcess +from pyrenew.process import SimpleRandomWalkProcess from pyrenew.latent import ( Infections, InfectionInitializationProcess, @@ -25,7 +25,7 @@ from pyrenew.latent import ( from pyrenew.observation import PoissonObservation from pyrenew.deterministic import DeterministicPMF from pyrenew.model import RtInfectionsRenewalModel -from pyrenew.metaclass import DistributionalRV +from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable import pyrenew.transformation as t ``` @@ -101,7 +101,7 @@ To initialize these five components within the renewal modeling framework, we es (2) an instance of the `InfectionInitializationProcess` class, where the number of latent infections immediately before the renewal process begins follows a log-normal distribution with mean = 0 and standard deviation = 1. By specifying `InitializeInfectionsZeroPad`, the latent infections before this time are assumed to be 0. -(3) an instance of the `RtRandomWalkProcess` class with default values +(3) A process to represent $\mathcal{R}(t)$ as a random walk on the log scale, with an inferred initial value and a fixed Normal step-size distribution. (4) an instance of the `Infections` class with default values, and @@ -121,11 +121,15 @@ I0 = InfectionInitializationProcess( t_unit=1, ) -# (3) The random process for Rt -rt_proc = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), +# (3) The random walk on log Rt +rt_proc = TransformedRandomVariable( + "Rt_rv", + base_rv=SimpleRandomWalkProcess( + name="log_rt", + step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), + init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + ), + transforms=t.ExpTransform(), ) # (4) Latent infection process (which will use 1 and 2) @@ -156,7 +160,7 @@ The following diagram summarizes how the modules interact via composition; notab flowchart TB genint["(1) gen_int\n(DetermnisticPMF)"] i0["(2) I0\n(InfectionInitializationProcess)"] - rt["(3) rt_proc\n(RtRandomWalkProcess)"] + rt["(3) rt_proc\n(TransformedRandomVariable)"] inf["(4) latent_infections\n(Infections)"] obs["(5) observation_process\n(PoissonObservation)"] diff --git a/docs/source/tutorials/extending_pyrenew.qmd b/docs/source/tutorials/extending_pyrenew.qmd index 468bd1b9..65d51052 100644 --- a/docs/source/tutorials/extending_pyrenew.qmd +++ b/docs/source/tutorials/extending_pyrenew.qmd @@ -28,8 +28,8 @@ import numpyro.distributions as dist from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.latent import InfectionsWithFeedback from pyrenew.model import RtInfectionsRenewalModel -from pyrenew.process import RtRandomWalkProcess -from pyrenew.metaclass import DistributionalRV +from pyrenew.process import SimpleRandomWalkProcess +from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable from pyrenew.latent import ( InfectionInitializationProcess, InitializeInfectionsExponentialGrowth, @@ -60,10 +60,14 @@ latent_infections = InfectionsWithFeedback( infection_feedback_pmf=gen_int, ) -rt = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), +rt = TransformedRandomVariable( + "Rt_rv", + base_rv=SimpleRandomWalkProcess( + name="log_rt", + step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), + init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + ), + transforms=t.ExpTransform(), ) ``` diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index f550ea98..02defaa6 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -169,10 +169,18 @@ I0 = InfectionInitializationProcess( # Generation interval and Rt gen_int = deterministic.DeterministicPMF(gen_int, name="gen_int") -rtproc = process.RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=transformation.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), +rtproc = metaclass.TransformedRandomVariable( + "Rt_rv", + base_rv=process.SimpleRandomWalkProcess( + name="log_rt", + step_rv=metaclass.DistributionalRV( + dist.Normal(0, 0.025), "rw_step_rv" + ), + init_rv=metaclass.DistributionalRV( + dist.Normal(0, 0.2), "init_log_Rt_rv" + ), + ), + transforms=transformation.ExpTransform(), ) # The observation model diff --git a/docs/source/tutorials/pyrenew_demo.qmd b/docs/source/tutorials/pyrenew_demo.qmd index 858b41a6..9bb6677b 100644 --- a/docs/source/tutorials/pyrenew_demo.qmd +++ b/docs/source/tutorials/pyrenew_demo.qmd @@ -35,6 +35,7 @@ import numpyro.distributions as dist ```{python} from pyrenew.process import SimpleRandomWalkProcess +from pyrenew.metaclass import DistributionalRV ``` To understand the simple random walk process underlying the sampling within the renewal process model, we first examine a single random walk path. Using the `sample` method from an instance of the `SimpleRandomWalkProcess` class, we first create an instance of the `SimpleRandomWalkProcess` class with a normal distribution of mean = 0 and standard deviation = 0.0001 as its input. Next, the `with` statement sets the seed for the random number generator for the n_timepoints of the block that follows. Inside the `with` block, the `q_samp = q(n_timepoints=100)` generates the sample instance over a n_timepoints of 100 time units. Finally, this single random walk process is visualized using `matplot.pyplot` to plot the exponential of the sample instance. @@ -43,7 +44,12 @@ To understand the simple random walk process underlying the sampling within the # | label: fig-randwalk # | fig-cap: Random walk example np.random.seed(3312) -q = SimpleRandomWalkProcess(dist.Normal(0, 0.001)) +q = SimpleRandomWalkProcess( + "example_random_walk", + step_rv=DistributionalRV(dist.Normal(0, 0.001), "step_rv"), + init_rv=DistributionalRV(dist.Normal(0, 0.001), "init_rv"), +) + with seed(rng_seed=np.random.randint(0, 1000)): q_samp = q(n_timepoints=100) @@ -57,7 +63,6 @@ from pyrenew.latent import ( Infections, HospitalAdmissions, ) -from pyrenew.metaclass import DistributionalRV ``` Additionally, import several classes from Pyrenew, including a Poisson observation process, determininstic PMF and variable classes, the Pyrenew hospitalization model, and a renewal model (Rt) random walk process: @@ -66,11 +71,11 @@ Additionally, import several classes from Pyrenew, including a Poisson observati from pyrenew.observation import PoissonObservation from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.model import HospitalAdmissionsModel -from pyrenew.process import RtRandomWalkProcess from pyrenew.latent import ( InfectionInitializationProcess, InitializeInfectionsZeroPad, ) +from pyrenew.metaclass import TransformedRandomVariable import pyrenew.transformation as t ``` @@ -84,9 +89,10 @@ To initialize the model, we first define initial conditions, including: 4) latent hospitalization process, modeled by first defining the time interval from infections to hospitalizations as a `DeterministicPMF` input with 18 possible outcomes and corresponding probabilities given by the values in the array. The `HospitalAdmissions` function then takes in this defined time interval, as well as defining the rate at which infections are admitted to the hospital due to infection, modeled as a log-normal distribution with mean = `jnp.log(0.05)` and standard deviation = 0.05. -5) hospitalization observation process, modeled with a Poisson distribution +5) hospitalization observation process, modeled with a Poisson distribution + +6) A process to represent $\mathcal{R}(t)$ as a random walk on the log scale, with an inferred initial value and a fixed Normal step-size distribution. -6) an Rt random walk process with default settings ```{python} # Initializing model components: @@ -126,12 +132,15 @@ latent_admissions = HospitalAdmissions( # 5) An observation process for the hospital admissions admissions_process = PoissonObservation("poisson_rv") -# 6) A random walk process (it could be deterministic using -# pyrenew.process.DeterministicProcess()) -Rt_process = RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=t.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), +# 6) The random walk on log Rt +Rt_process = TransformedRandomVariable( + "Rt_rv", + base_rv=SimpleRandomWalkProcess( + name="log_rt", + step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), + init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), + ), + transforms=t.ExpTransform(), ) ``` diff --git a/model/pyproject.toml b/model/pyproject.toml index f7f963a7..428efeb3 100755 --- a/model/pyproject.toml +++ b/model/pyproject.toml @@ -34,6 +34,7 @@ pytest-cov = "^5.0.0" pytest-mpl = "^0.17.0" numpydoc = "^1.7.0" arviz = "^0.18.0" +quarto = "^0.1.0" [tool.numpydoc_validation] checks = [ From 62b50018e7b08ac69e8295d40645e251574a5e4f Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Thu, 18 Jul 2024 11:44:56 -0400 Subject: [PATCH 07/33] Custom Rt RV in tutorials --- docs/source/tutorials/basic_renewal_model.qmd | 46 +++++++++++----- .../tutorials/hospital_admissions_model.qmd | 53 ++++++++++++------- 2 files changed, 67 insertions(+), 32 deletions(-) diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index 4efec4b5..f66746c5 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -14,7 +14,7 @@ We start by loading the needed components to build a basic renewal model: # | warning: false import jax.numpy as jnp import numpy as np -import numpyro as npro +import numpyro import numpyro.distributions as dist from pyrenew.process import SimpleRandomWalkProcess from pyrenew.latent import ( @@ -25,7 +25,11 @@ from pyrenew.latent import ( from pyrenew.observation import PoissonObservation from pyrenew.deterministic import DeterministicPMF from pyrenew.model import RtInfectionsRenewalModel -from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable +from pyrenew.metaclass import ( + RandomVariable, + DistributionalRV, + TransformedRandomVariable, +) import pyrenew.transformation as t ``` @@ -121,16 +125,32 @@ I0 = InfectionInitializationProcess( t_unit=1, ) -# (3) The random walk on log Rt -rt_proc = TransformedRandomVariable( - "Rt_rv", - base_rv=SimpleRandomWalkProcess( - name="log_rt", - step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), - init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), - ), - transforms=t.ExpTransform(), -) + +# (3) The random walk on log Rt, with an inferred s.d. +class MyRt(RandomVariable): + + def validate(self): + pass + + def sample(self, n_timepoints: int, **kwargs) -> tuple: + sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025)) + + rt_rv = TransformedRandomVariable( + "Rt_rv", + base_rv=SimpleRandomWalkProcess( + name="log_rt", + step_rv=DistributionalRV(dist.Normal(0, sd_rt), "rw_step_rv"), + init_rv=DistributionalRV( + dist.Normal(0, 0.2), "init_log_Rt_rv" + ), + ), + transforms=t.ExpTransform(), + ) + + return rt_rv.sample(n_timepoints=n_timepoints, **kwargs) + + +rt_proc = MyRt() # (4) Latent infection process (which will use 1 and 2) latent_infections = Infections() @@ -178,7 +198,7 @@ Using `numpyro`, we can simulate data using the `sample()` member function of `R ```{python} # | label: simulate np.random.seed(223) -with npro.handlers.seed(rng_seed=np.random.randint(1, 60)): +with numpyro.handlers.seed(rng_seed=np.random.randint(1, 60)): sim_data = model1.sample(n_timepoints_to_simulate=30) sim_data diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index 02defaa6..b3dd36df 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -153,6 +153,7 @@ from pyrenew.latent import ( InitializeInfectionsExponentialGrowth, ) + # Infection process latent_inf = latent.Infections() I0 = InfectionInitializationProcess( @@ -169,19 +170,34 @@ I0 = InfectionInitializationProcess( # Generation interval and Rt gen_int = deterministic.DeterministicPMF(gen_int, name="gen_int") -rtproc = metaclass.TransformedRandomVariable( - "Rt_rv", - base_rv=process.SimpleRandomWalkProcess( - name="log_rt", - step_rv=metaclass.DistributionalRV( - dist.Normal(0, 0.025), "rw_step_rv" - ), - init_rv=metaclass.DistributionalRV( - dist.Normal(0, 0.2), "init_log_Rt_rv" - ), - ), - transforms=transformation.ExpTransform(), -) + + +class MyRt(metaclass.RandomVariable): + + def validate(self): + pass + + def sample(self, n_timepoints: int, **kwargs) -> tuple: + sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025)) + + rt_rv = metaclass.TransformedRandomVariable( + "Rt_rv", + base_rv=process.SimpleRandomWalkProcess( + name="log_rt", + step_rv=metaclass.DistributionalRV( + dist.Normal(0, sd_rt), "rw_step_rv" + ), + init_rv=metaclass.DistributionalRV( + dist.Normal(0, 0.2), "init_log_Rt_rv" + ), + ), + transforms=transformation.ExpTransform(), + ) + + return rt_rv.sample(n_timepoints=n_timepoints, **kwargs) + + +rtproc = MyRt() # The observation model obs = observation.NegativeBinomialObservation( @@ -213,13 +229,13 @@ Let's simulate to check if the model is working: ```{python} # | label: simulation -import numpyro as npro +import numpyro import numpy as np timeframe = 120 np.random.seed(223) -with npro.handlers.seed(rng_seed=np.random.randint(1, timeframe)): +with numpyro.handlers.seed(rng_seed=np.random.randint(1, timeframe)): sim_data = hosp_model.sample(n_timepoints_to_simulate=timeframe) ``` @@ -253,7 +269,7 @@ We can fit the model to the data. We will use the `run` method of the model obje # | label: model-fit import jax -npro.set_host_device_count(jax.local_device_count()) +numpyro.set_host_device_count(jax.local_device_count()) hosp_model.run( num_samples=1000, num_warmup=1000, @@ -478,7 +494,6 @@ Note a similar weekday effect is implemented in its own module, with example cod ```{python} # | label: weekly-effect from pyrenew import metaclass -import numpyro as npro class DayOfWeekEffect(metaclass.RandomVariable): @@ -499,9 +514,9 @@ class DayOfWeekEffect(metaclass.RandomVariable): return None def sample(self, **kwargs): - ans = npro.sample( + ans = numpyro.sample( name="dayofweek_effect", - fn=npro.distributions.TruncatedNormal( + fn=numpyro.distributions.TruncatedNormal( loc=1.0, scale=0.5, low=0.1, high=10.0 ), sample_shape=(7,), From e714bdd14dc0f4bcb3feb7cfddcd055fc17ae756 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Thu, 18 Jul 2024 12:24:50 -0400 Subject: [PATCH 08/33] Tutorial prior tweaks --- docs/source/tutorials/basic_renewal_model.qmd | 9 ++++----- model/src/pyrenew/model/rtinfectionsrenewalmodel.py | 6 +++--- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index f66746c5..2f6aa795 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -114,13 +114,13 @@ To initialize these five components within the renewal modeling framework, we es ```{python} # | label: creating-elements # (1) The generation interval (deterministic) -pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25]) +pmf_array = jnp.array([0.4, 0.3, 0.2, 0.1]) gen_int = DeterministicPMF(pmf_array, name="gen_int") # (2) Initial infections (inferred with a prior) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + DistributionalRV(dist=dist.LogNormal(1, 1), name="I0"), InitializeInfectionsZeroPad(pmf_array.size), t_unit=1, ) @@ -141,7 +141,7 @@ class MyRt(RandomVariable): name="log_rt", step_rv=DistributionalRV(dist.Normal(0, sd_rt), "rw_step_rv"), init_rv=DistributionalRV( - dist.Normal(0, 0.2), "init_log_Rt_rv" + dist.Normal(jnp.log(1.2), jnp.log(1.5)), "init_log_Rt_rv" ), ), transforms=t.ExpTransform(), @@ -197,8 +197,7 @@ Using `numpyro`, we can simulate data using the `sample()` member function of `R ```{python} # | label: simulate -np.random.seed(223) -with numpyro.handlers.seed(rng_seed=np.random.randint(1, 60)): +with numpyro.handlers.seed(rng_seed=353): sim_data = model1.sample(n_timepoints_to_simulate=30) sim_data diff --git a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py index f6f9846f..21fca320 100644 --- a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py +++ b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py @@ -6,7 +6,7 @@ from typing import NamedTuple import jax.numpy as jnp -import numpyro as npro +import numpyro import pyrenew.arrayutils as au from numpy.typing import ArrayLike from pyrenew.deterministic import NullObservation @@ -223,7 +223,7 @@ def sample( all_latent_infections = jnp.hstack( [I0, post_initialization_latent_infections] ) - npro.deterministic("all_latent_infections", all_latent_infections) + numpyro.deterministic("all_latent_infections", all_latent_infections) if observed_infections is not None: observed_infections = au.pad_x_to_match_y( @@ -239,7 +239,7 @@ def sample( jnp.nan, pad_direction="start", ) - npro.deterministic("Rt", Rt) + numpyro.deterministic("Rt", Rt) return RtInfectionsRenewalSample( Rt=Rt, From 530870cf1caf84cfe56943aaee09278464c003a2 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Thu, 18 Jul 2024 12:52:04 -0400 Subject: [PATCH 09/33] More tutorial prior tweaks --- docs/source/tutorials/basic_renewal_model.qmd | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index 2f6aa795..d1efa0c0 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -120,7 +120,7 @@ gen_int = DeterministicPMF(pmf_array, name="gen_int") # (2) Initial infections (inferred with a prior) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(dist=dist.LogNormal(1, 1), name="I0"), + DistributionalRV(dist=dist.LogNormal(2.5, 1), name="I0"), InitializeInfectionsZeroPad(pmf_array.size), t_unit=1, ) @@ -141,7 +141,7 @@ class MyRt(RandomVariable): name="log_rt", step_rv=DistributionalRV(dist.Normal(0, sd_rt), "rw_step_rv"), init_rv=DistributionalRV( - dist.Normal(jnp.log(1.2), jnp.log(1.5)), "init_log_Rt_rv" + dist.Normal(jnp.log(1.5), jnp.log(1.2)), "init_log_Rt_rv" ), ), transforms=t.ExpTransform(), From 558d84a0951f667b0550fc35ceb8e22a41ca180b Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Thu, 18 Jul 2024 14:09:47 -0400 Subject: [PATCH 10/33] Don't plot nan-padded Rt values --- docs/source/tutorials/basic_renewal_model.qmd | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index d1efa0c0..25b89471 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -275,10 +275,10 @@ diagnostic_stats_summary = az.summary( kind="diagnostics", ) -print(diagnostic_stats_summary[:10]) +print(diagnostic_stats_summary[4:14]) ``` -Below we use `plot_trace` to inspect the trace of the first 10 $R_t$ estimates. +Below we use `plot_trace` to inspect the trace of the first 10 inferred $\mathcal{R}(t)$ values. ```{python} # | label: fig-trace-Rt From e6862835e99388a14b9e623852b51d638b16eeaa Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Thu, 18 Jul 2024 14:29:57 -0400 Subject: [PATCH 11/33] $ to $\mathcal{R}(t)$ throughout basic renewal tutorial --- docs/source/tutorials/basic_renewal_model.qmd | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index 25b89471..d1a6fee0 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -87,13 +87,13 @@ flowchart LR ``` -The pyrenew package models the real-time reproductive number $R_t$, the average number of secondary infections caused by an infected individual, as a renewal process model. Our basic renewal process model defines five components: +The pyrenew package models the real-time reproductive number $\mathcal{R}(t)$, the average number of secondary infections caused by an infected individual, as a renewal process model. Our basic renewal process model defines five components: (1) generation interval, the times between infections (2) initial infections, occurring prior to time $t = 0$ -(3) $R_t$, the real-time reproductive number, +(3) $\mathcal{R}(t)$, the real-time reproductive number, (4) latent infections, i.e., those infections which are known to exist but are not observed (or not observable), and @@ -203,7 +203,7 @@ with numpyro.handlers.seed(rng_seed=353): sim_data ``` -To understand what has been accomplished here, visualize an $R_t$ sample path (left panel) and infections over time (right panel): +To understand what has been accomplished here, visualize an $\mathcal{R}(t)$ sample path (left panel) and infections over time (right panel): ```{python} # | label: fig-basic @@ -243,7 +243,7 @@ model1.run( ) ``` -Now, let's investigate the output, particularly the posterior distribution of the $R_t$ estimates: +Now, let's investigate the output, particularly the posterior distribution of the $\mathcal{R}(t)$ estimates: ```{python} # | label: fig-output-rt @@ -266,7 +266,7 @@ import arviz as az idata = az.from_numpyro(model1.mcmc) ``` -and use the InferenceData to compute the model-fit diagnostics. Here, we show diagnostic summary for the first 10 effective reproduction number $R_t$. +and use the InferenceData to compute the model-fit diagnostics. Here, we show diagnostic summary for the first 10 effective reproduction number $\mathcal{R}(t)$. ```{python} # | label: diagnostics @@ -295,7 +295,7 @@ plt.show() ``` -We inspect the posterior distribution of $R_t$ by plotting the 90% and 50% highest density intervals: +We inspect the posterior distribution of $\mathcal{R}(t)$ by plotting the 90% and 50% highest density intervals: ```{python} # | label: fig-hdi-Rt @@ -328,7 +328,7 @@ axes.plot(x_data, mean_Rt[0], color="C0", label="Mean") axes.legend() axes.set_title("Posterior Effective Reproduction Number", fontsize=10) axes.set_xlabel("Time", fontsize=10) -axes.set_ylabel("$R_t$", fontsize=10) +axes.set_ylabel("$\mathcal{R}(t)$", fontsize=10) plt.show() ``` From 7875968fee5ae6ac7bb90ff735867ee62d99e165 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Thu, 18 Jul 2024 14:46:55 -0400 Subject: [PATCH 12/33] Better handling of nan padded Rt --- docs/source/tutorials/basic_renewal_model.qmd | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index d1a6fee0..72231bbc 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -271,11 +271,11 @@ and use the InferenceData to compute the model-fit diagnostics. Here, we show di ```{python} # | label: diagnostics diagnostic_stats_summary = az.summary( - idata.posterior["Rt"], + idata.posterior["Rt"][4:], # ignore Nan padding kind="diagnostics", ) -print(diagnostic_stats_summary[4:14]) +print(diagnostic_stats_summary[10]) ``` Below we use `plot_trace` to inspect the trace of the first 10 inferred $\mathcal{R}(t)$ values. From 7ed29cd26cc5fe893d98193143763f689b076570 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Thu, 18 Jul 2024 15:43:19 -0400 Subject: [PATCH 13/33] Add reparam option to DistributionalRV, use for basic tutorial --- docs/source/tutorials/basic_renewal_model.qmd | 20 ++++++++------ model/src/pyrenew/metaclass.py | 27 ++++++++++++------- 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index 72231bbc..8a098fd4 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -31,6 +31,7 @@ from pyrenew.metaclass import ( TransformedRandomVariable, ) import pyrenew.transformation as t +from numpyro.infer.reparam import LocScaleReparam ``` ## Architecture of `RtInfectionsRenewalModel` @@ -139,14 +140,17 @@ class MyRt(RandomVariable): "Rt_rv", base_rv=SimpleRandomWalkProcess( name="log_rt", - step_rv=DistributionalRV(dist.Normal(0, sd_rt), "rw_step_rv"), + step_rv=DistributionalRV( + dist.Normal(0, sd_rt), + "rw_step_rv", + reparam=LocScaleReparam(0), + ), init_rv=DistributionalRV( - dist.Normal(jnp.log(1.5), jnp.log(1.2)), "init_log_Rt_rv" + dist.Normal(jnp.log(1), jnp.log(1.2)), "init_log_Rt_rv" ), ), transforms=t.ExpTransform(), ) - return rt_rv.sample(n_timepoints=n_timepoints, **kwargs) @@ -197,8 +201,8 @@ Using `numpyro`, we can simulate data using the `sample()` member function of `R ```{python} # | label: simulate -with numpyro.handlers.seed(rng_seed=353): - sim_data = model1.sample(n_timepoints_to_simulate=30) +with numpyro.handlers.seed(rng_seed=53): + sim_data = model1.sample(n_timepoints_to_simulate=40) sim_data ``` @@ -271,11 +275,11 @@ and use the InferenceData to compute the model-fit diagnostics. Here, we show di ```{python} # | label: diagnostics diagnostic_stats_summary = az.summary( - idata.posterior["Rt"][4:], # ignore Nan padding + idata.posterior["Rt"][::, ::, 4:], # ignore nan padding kind="diagnostics", ) -print(diagnostic_stats_summary[10]) +print(diagnostic_stats_summary) ``` Below we use `plot_trace` to inspect the trace of the first 10 inferred $\mathcal{R}(t)$ values. @@ -288,7 +292,7 @@ plt.rcParams["figure.constrained_layout.use"] = True az.plot_trace( idata.posterior, var_names=["Rt"], - coords={"Rt_dim_0": np.arange(10)}, + coords={"Rt_dim_0": np.arange(4, 14)}, compact=False, ) plt.show() diff --git a/model/src/pyrenew/metaclass.py b/model/src/pyrenew/metaclass.py index 03b417ea..eb3032a6 100644 --- a/model/src/pyrenew/metaclass.py +++ b/model/src/pyrenew/metaclass.py @@ -16,6 +16,7 @@ import polars as pl from jax.typing import ArrayLike from numpyro.infer import MCMC, NUTS, Predictive +from numpyro.infer.reparam import Reparam from pyrenew.mcmcutils import plot_posterior, spread_draws from pyrenew.transformation import Transform @@ -220,6 +221,7 @@ def __init__( self, dist: numpyro.distributions.Distribution, name: str, + reparam: Reparam = None, ) -> None: """ Default constructor for DistributionalRV. @@ -231,6 +233,11 @@ def __init__( name : str Name of the random variable. + reparam : numpyro.infer.reparam.Reparam + If not None, reparameterize sampling + from the distribution according to the + given numpyro reparameterizer + Returns ------- None @@ -240,6 +247,10 @@ def __init__( self.dist = dist self.name = name + if reparam is not None: + self.reparam_dict = {self.name: reparam} + else: + self.reparam_dict = {} return None @@ -278,15 +289,13 @@ def sample( tuple Containing the sampled from the distribution. """ - return ( - jnp.atleast_1d( - numpyro.sample( - name=self.name, - fn=self.dist, - obs=obs, - ) - ), - ) + with numpyro.handlers.reparam(config=self.reparam_dict): + sample = numpyro.sample( + name=self.name, + fn=self.dist, + obs=obs, + ) + return (jnp.atleast_1d(sample),) class Model(metaclass=ABCMeta): From 2ef20752fe91dba48e40d7964c921d018cfa188f Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Fri, 19 Jul 2024 15:52:09 -0400 Subject: [PATCH 14/33] Update docs/source/tutorials/hospital_admissions_model.qmd --- docs/source/tutorials/hospital_admissions_model.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index 2f8b56c5..cb3a0735 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -251,7 +251,7 @@ import numpy as np timeframe = 120 -with npro.handlers.seed(rng_seed=223): +with numpyro.handlers.seed(rng_seed=223): simulated_data = hosp_model.sample(n_timepoints_to_simulate=timeframe) ``` From 2c1c7c7d087924661ba60409b86ee290c3481faf Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Fri, 19 Jul 2024 16:39:56 -0400 Subject: [PATCH 15/33] Update model/src/pyrenew/metaclass.py Co-authored-by: Damon Bayer --- model/src/pyrenew/metaclass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/src/pyrenew/metaclass.py b/model/src/pyrenew/metaclass.py index e136d347..cc7418c2 100644 --- a/model/src/pyrenew/metaclass.py +++ b/model/src/pyrenew/metaclass.py @@ -228,7 +228,7 @@ def __init__( Parameters ---------- - dist : dist.Distribution + dist : numpyro.distributions.Distribution Distribution of the random variable. name : str Name of the random variable. From 461796334b432fb24de2f70733b91f9cb398b5db Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Fri, 19 Jul 2024 16:40:26 -0400 Subject: [PATCH 16/33] Update model/pyproject.toml --- model/pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/model/pyproject.toml b/model/pyproject.toml index 428efeb3..f7f963a7 100755 --- a/model/pyproject.toml +++ b/model/pyproject.toml @@ -34,7 +34,6 @@ pytest-cov = "^5.0.0" pytest-mpl = "^0.17.0" numpydoc = "^1.7.0" arviz = "^0.18.0" -quarto = "^0.1.0" [tool.numpydoc_validation] checks = [ From af8a7d94b3f48188b29d88c1130a84ca0ba4c7cb Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 22 Jul 2024 09:27:09 -0400 Subject: [PATCH 17/33] Escape backlash --- docs/source/tutorials/basic_renewal_model.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index 278e42b1..e6a3fdb8 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -332,7 +332,7 @@ axes.plot(x_data, mean_Rt[0], color="C0", label="Mean") axes.legend() axes.set_title("Posterior Effective Reproduction Number", fontsize=10) axes.set_xlabel("Time", fontsize=10) -axes.set_ylabel("$\mathcal{R}(t)$", fontsize=10) +axes.set_ylabel("$\\mathcal{R}(t)$", fontsize=10) plt.show() ``` From 58690337565ca5ccc1d5ae04485fbfc07fe8826b Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 22 Jul 2024 09:31:14 -0400 Subject: [PATCH 18/33] Comment on custom RV --- docs/source/tutorials/basic_renewal_model.qmd | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index e6a3fdb8..2b83497f 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -108,7 +108,7 @@ To initialize these five components within the renewal modeling framework, we es (2) an instance of the `InfectionInitializationProcess` class, where the number of latent infections immediately before the renewal process begins follows a log-normal distribution with mean = 0 and standard deviation = 1. By specifying `InitializeInfectionsZeroPad`, the latent infections before this time are assumed to be 0. -(3) A process to represent $\mathcal{R}(t)$ as a random walk on the log scale, with an inferred initial value and a fixed Normal step-size distribution. +(3) A process to represent $\mathcal{R}(t)$ as a random walk on the log scale, with an inferred initial value and a fixed Normal step-size distribution. For this, we construct a custom `RandomVariable`, `MyRt`. (4) an instance of the `Infections` class with default values, and @@ -129,7 +129,8 @@ I0 = InfectionInitializationProcess( ) -# (3) The random walk on log Rt, with an inferred s.d. +# (3) The random walk on log Rt, with an inferred s.d. Here, we +# construct a custom RandomVariable. class MyRt(RandomVariable): def validate(self): @@ -186,7 +187,7 @@ The following diagram summarizes how the modules interact via composition; notab flowchart TB genint["(1) gen_int\n(DetermnisticPMF)"] i0["(2) I0\n(InfectionInitializationProcess)"] - rt["(3) rt_proc\n(TransformedRandomVariable)"] + rt["(3) rt_proc\n(MyRt - the custom RV defined above))"] inf["(4) latent_infections\n(Infections)"] obs["(5) observation_process\n(PoissonObservation)"] From 1081ae201f939f51eff0185dbb8287f01c8653b8 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 22 Jul 2024 09:34:16 -0400 Subject: [PATCH 19/33] Rename admissions model test and fix RNG seed pattern --- ..._hospitalizations.py => test_model_hosp_admissions.py} | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) rename model/src/test/{test_model_hospitalizations.py => test_model_hosp_admissions.py} (98%) diff --git a/model/src/test/test_model_hospitalizations.py b/model/src/test/test_model_hosp_admissions.py similarity index 98% rename from model/src/test/test_model_hospitalizations.py rename to model/src/test/test_model_hosp_admissions.py index be352bcf..1abb2a3b 100644 --- a/model/src/test/test_model_hospitalizations.py +++ b/model/src/test/test_model_hosp_admissions.py @@ -73,7 +73,7 @@ def sample(self, **kwargs): # numpydoc ignore=GL08 def test_model_hosp_no_timepoints_or_observations(): """ - Checks that the Hospitalization model does not run + Checks that the hospital admissions model does not run without either n_timepoints_to_simulate or observed_admissions """ @@ -130,8 +130,7 @@ def test_model_hosp_no_timepoints_or_observations(): hosp_admission_obs_process_rv=observed_admissions, ) - np.random.seed(223) - with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)): + with numpyro.handlers.seed(rng_seed=233): with pytest.raises(ValueError, match="Either"): model1.sample( n_timepoints_to_simulate=None, data_observed_admissions=None @@ -140,7 +139,8 @@ def test_model_hosp_no_timepoints_or_observations(): def test_model_hosp_both_timepoints_and_observations(): """ - Checks that the Hospitalization model does not run with both n_timepoints_to_simulate and observed_admissions passed + Checks that the hospital admissions model does not run with + both n_timepoints_to_simulate and observed_admissions passed """ gen_int = DeterministicPMF( From 74b553dde141ee8515ccd1fefc29a55c33e09099 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 22 Jul 2024 09:41:43 -0400 Subject: [PATCH 20/33] Update model/src/pyrenew/process/simplerandomwalk.py Co-authored-by: Damon Bayer --- model/src/pyrenew/process/simplerandomwalk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/src/pyrenew/process/simplerandomwalk.py b/model/src/pyrenew/process/simplerandomwalk.py index 87d36dab..d793ab37 100644 --- a/model/src/pyrenew/process/simplerandomwalk.py +++ b/model/src/pyrenew/process/simplerandomwalk.py @@ -72,7 +72,7 @@ def sample( With a single array of shape (n_timepoints,). """ - init, *_ = self.init_rv.sample(**kwargs) + init, *_ = self.init_rv(**kwargs) def transition(x_prev, _): # numpydoc ignore=GL08 From ba567ba22c14b353c17539de8f3b3685aca098b6 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 22 Jul 2024 09:41:52 -0400 Subject: [PATCH 21/33] Update model/src/pyrenew/process/simplerandomwalk.py Co-authored-by: Damon Bayer --- model/src/pyrenew/process/simplerandomwalk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/src/pyrenew/process/simplerandomwalk.py b/model/src/pyrenew/process/simplerandomwalk.py index d793ab37..de66c3fe 100644 --- a/model/src/pyrenew/process/simplerandomwalk.py +++ b/model/src/pyrenew/process/simplerandomwalk.py @@ -76,7 +76,7 @@ def sample( def transition(x_prev, _): # numpydoc ignore=GL08 - diff, *_ = self.step_rv.sample(**kwargs) + diff, *_ = self.step_rv(**kwargs) x_curr = x_prev + diff return x_curr, x_curr From 44857ff0cceb93e62e31243b3d8ddd838a8c1f5d Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 22 Jul 2024 09:53:54 -0400 Subject: [PATCH 22/33] Update other mermaid diagram --- docs/source/tutorials/basic_renewal_model.qmd | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index 2b83497f..67274f72 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -65,7 +65,7 @@ flowchart LR end subgraph process[Process module] - rt["rt_proc\n(RtRandomWalkProcess)"] + rt["SimpleRandomWalk, which we fill use to create a custom RandomVariable to represent R(t)"] end subgraph deterministic[Deterministic module] @@ -96,7 +96,7 @@ The pyrenew package models the real-time reproductive number $\mathcal{R}(t)$, t (2) initial infections, occurring prior to time $t = 0$ -(3) $\mathcal{R}(t)$, the real-time reproductive number, +(3) $\mathcal{R}(t)$, the time-varying reproductive number, (4) latent infections, i.e., those infections which are known to exist but are not observed (or not observable), and @@ -187,7 +187,7 @@ The following diagram summarizes how the modules interact via composition; notab flowchart TB genint["(1) gen_int\n(DetermnisticPMF)"] i0["(2) I0\n(InfectionInitializationProcess)"] - rt["(3) rt_proc\n(MyRt - the custom RV defined above))"] + rt["(3) rt_proc\n(MyRt, the custom RV defined above)"] inf["(4) latent_infections\n(Infections)"] obs["(5) observation_process\n(PoissonObservation)"] From 8ee433f2212f4255102780697265630ca5e581c2 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 22 Jul 2024 12:40:39 -0400 Subject: [PATCH 23/33] Linear scale simulated admissions, with points --- docs/source/tutorials/hospital_admissions_model.qmd | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index cb3a0735..d3537c87 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -267,9 +267,8 @@ axs[0].plot(simulated_data.Rt) axs[0].set_ylabel("Simulated Rt") # Admissions plot -axs[1].plot(simulated_data.observed_hosp_admissions) +axs[1].plot(simulated_data.observed_hosp_admissions, "-o") axs[1].set_ylabel("Simulated Admissions") -axs[1].set_yscale("log") fig.suptitle("Basic renewal model") fig.supxlabel("Time") From a0fc6f79707d6d481413b1643ffa0e8f233dd067 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 22 Jul 2024 14:15:27 -0400 Subject: [PATCH 24/33] Fix composition diagram --- docs/source/tutorials/basic_renewal_model.qmd | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index 67274f72..78234578 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -56,20 +56,20 @@ flowchart LR models((Model\nmetaclass)) subgraph observations[Observations module] - obs["observation_process\n(PoissonObservation)"] + obs["infection_obs_process_rv\n(PoissonObservation)"] end subgraph latent[Latent module] - inf["latent_infections\n(Infections)"] - i0["I0\n(DistributionalRV)"] + inf["latent_infections_rv\n(Infections)"] + i0["I0_rv\n(DistributionalRV)"] end subgraph process[Process module] - rt["SimpleRandomWalk, which we fill use to create a custom RandomVariable to represent R(t)"] + rt["Rt_process_rv\n(Custom class built using SimpleRandomWalk)"] end subgraph deterministic[Deterministic module] - detpmf["gen_int\n(DeterministicPMF)"] + detpmf["gen_int_rv\n(DeterministicPMF)"] end subgraph model[Model module] From 931977b7fee00efaae7d65900de2036964235aed Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 22 Jul 2024 14:21:58 -0400 Subject: [PATCH 25/33] Hospitalization => hospital admission throughout --- .../source/tutorials/hospital_admissions_model.qmd | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index d3537c87..d100a5f0 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -16,7 +16,7 @@ This document illustrates how a hospital admissions-only model can be fitted usi ## Model definition -In this section, we provide the formal definition of the model. The hospitalization model is a semi-mechanistic model that describes the number of observed hospital admissions as a function of a set of latent variables. Mainly, the observed number of hospital admissions is discretely distributed with location at the number of latent hospital admissions: +In this section, we provide the formal definition of the model. The hospital admissions model is a semi-mechanistic model that describes the number of observed hospital admissions as a function of a set of latent variables. Mainly, the observed number of hospital admissions is discretely distributed with location at the number of latent hospital admissions: $$ h(t) \sim \text{HospDist}\left(H(t)\right) @@ -33,9 +33,9 @@ H(t) & = p_\mathrm{hosp}(t) \sum_{\tau = 0}^{T_d} d(\tau) I(t-\tau) \\ \end{align*} $$ -Were $d(\tau)$ is the infection to hospitalization interval, $I(t)$ is the number of latent infections at time $t$, $p_\mathrm{hosp}(t)$ is the infection to hospitalization rate. +Were $d(\tau)$ is the infection to hospital admission interval, $I(t)$ is the number of latent infections at time $t$, $p_\mathrm{hosp}(t)$ is the infection to admission rate. -The number of latent hospital admissions at time $t$ is a function of the number of latent infections at time $t$ and the infection to hospitalization rate. The latent infections are modeled as a renewal process: +The number of latent hospital admissions at time $t$ is a function of the number of latent infections at time $t$ and the infection to admission rate. The latent infections are modeled as a renewal process: $$ \begin{align*} @@ -104,11 +104,11 @@ plt.show() ## Building the model -First, we will extract two datasets we will use as deterministic quantities: the generation interval and the infection to hospitalization interval. +First, we will extract two datasets we will use as deterministic quantities: the generation interval and the infection to hospital admission interval. ```{python} # | label: fig-data-extract -# | fig-cap: Generation interval and infection to hospitalization interval +# | fig-cap: Generation interval and infection to hospital admission interval gen_int = datasets.load_generation_interval() inf_hosp_int = datasets.load_infection_admission_interval() @@ -126,7 +126,7 @@ fig, axs = plt.subplots(1, 2) axs[0].plot(gen_int) axs[0].set_title("Generation interval") axs[1].plot(inf_hosp_int) -axs[1].set_title("Infection to hospitalization interval") +axs[1].set_title("Infection to hospital admission interval") plt.show() ``` @@ -153,7 +153,7 @@ latent_hosp = latent.HospitalAdmissions( ) ``` -The `inf_hosp_int` is a `DeterministicPMF` object that takes the infection to hospitalization interval as input. The `hosp_rate` is a `DistributionalRV` object that takes a numpyro distribution to represent the infection to hospitalization rate. The `HospitalAdmissions` class is a `RandomVariable` that takes two distributions as inputs: the infection to admission interval and the infection to hospitalization rate. Now, we can define the rest of the other components: +The `inf_hosp_int` is a `DeterministicPMF` object that takes the infection to hospital admission interval as input. The `hosp_rate` is a `DistributionalRV` object that takes a numpyro distribution to represent the infection to hospital admission rate. The `HospitalAdmissions` class is a `RandomVariable` that takes two distributions as inputs: the infection to admission interval and the infection to hospital admission rate. Now, we can define the rest of the other components: ```{python} # | label: initializing-rest-of-model From 6e7a5e7f80f66977a64ba30f1a400ca4bd858930 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 22 Jul 2024 14:54:37 -0400 Subject: [PATCH 26/33] np.mean to arviz mean for posterior means, and include both chains --- .../source/tutorials/hospital_admissions_model.qmd | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index d100a5f0..a6f4453c 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -343,6 +343,7 @@ Below we plot 90% and 50% highest density intervals for latent hospital admissio x_data = idata.posterior["latent_hospital_admissions_dim_0"] y_data = idata.posterior["latent_hospital_admissions"] + fig, axes = plt.subplots(figsize=(6, 5)) az.plot_hdi( x_data, @@ -365,10 +366,9 @@ az.plot_hdi( ) # Add mean of the posterior to the figure -mean_latent_hosp_admission = np.mean( - idata.posterior["latent_hospital_admissions"], axis=1 -) -axes.plot(x_data, mean_latent_hosp_admission[0], color="C0", label="Mean") +mean_latent_hosp_admission = y_data.mean(dim=["chain", "draw"]) + +axes.plot(x_data, mean_latent_hosp_admission, color="C0", label="Mean") axes.legend() axes.set_title("Posterior Hospital Admissions", fontsize=10) axes.set_xlabel("Time", fontsize=10) @@ -547,13 +547,11 @@ az.plot_hdi( ) # Add mean of the posterior to the figure -mean_latent_infection = np.mean( - idata_weekday.posterior_predictive["negbinom_rv"], axis=1 -) +mean_latent_infection = y_data.mean(dim=["chain", "draw"]) plt.plot( idata_weekday.posterior_predictive["negbinom_rv_dim_0"] + gen_int.size(), - mean_latent_infection[0], + mean_latent_infection, color="C0", label="Mean", ) From 088c28fc9989924b76fd6929ee4c9674cb3df06c Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 22 Jul 2024 14:56:28 -0400 Subject: [PATCH 27/33] np.mean to arviz mean for posterior means, and include both chains in basic renewal --- docs/source/tutorials/basic_renewal_model.qmd | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index 78234578..9d0fe97c 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -328,8 +328,8 @@ az.plot_hdi( ) # Add mean of the posterior to the figure -mean_Rt = np.mean(idata.posterior["Rt"], axis=1) -axes.plot(x_data, mean_Rt[0], color="C0", label="Mean") +mean_Rt = y_data.mean(dim=["chain", "draw"]) +axes.plot(x_data, mean_Rt, color="C0", label="Mean") axes.legend() axes.set_title("Posterior Effective Reproduction Number", fontsize=10) axes.set_xlabel("Time", fontsize=10) @@ -367,10 +367,9 @@ az.plot_hdi( ) # Add mean of the posterior to the figure -mean_latent_infection = np.mean( - idata.posterior["all_latent_infections"], axis=1 -) -axes.plot(x_data, mean_latent_infection[0], color="C0", label="Mean") +mean_latent_infections = y_data.mean(dim=["chain", "draw"]) + +axes.plot(x_data, mean_latent_infections, color="C0", label="Mean") axes.legend() axes.set_title("Posterior Latent Infections", fontsize=10) axes.set_xlabel("Time", fontsize=10) From c1f1cc96ecb6e65cb24b4c2383ccb679a4123447 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 22 Jul 2024 14:59:04 -0400 Subject: [PATCH 28/33] Fix typo --- docs/source/tutorials/hospital_admissions_model.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index a6f4453c..5496597c 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -117,7 +117,7 @@ gen_int_array = gen_int["probability_mass"].to_numpy() gen_int = gen_int_array inf_hosp_int = inf_hosp_int["probability_mass"].to_numpy() -# Taking a pick at the first 5 elements of each +# Taking a peek at the first 5 elements of each gen_int[:5], inf_hosp_int[:5] # Visualizing both quantities side by side From a033b7d1ca904231ea16fbea44290db2e0c0542a Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 22 Jul 2024 15:11:11 -0400 Subject: [PATCH 29/33] means to medians in hospital_admissions tutorial, clarify plotting code --- .../tutorials/hospital_admissions_model.qmd | 37 +++++++++++-------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index 5496597c..edb60ef7 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -365,10 +365,10 @@ az.plot_hdi( ax=axes, ) -# Add mean of the posterior to the figure -mean_latent_hosp_admission = y_data.mean(dim=["chain", "draw"]) +# Add the posterior median to the figure +median_ts = y_data.median(dim=["chain", "draw"]) -axes.plot(x_data, mean_latent_hosp_admission, color="C0", label="Mean") +axes.plot(x_data, median_ts, color="C0", label="Median") axes.legend() axes.set_title("Posterior Hospital Admissions", fontsize=10) axes.set_xlabel("Time", fontsize=10) @@ -414,6 +414,11 @@ az.plot_hdi( fill_kwargs={"alpha": 0.6}, ax=axes, ) + +# Add the posterior median to the figure +median_ts = y_data.median(dim=["chain", "draw"]) +axes.plot(x_data, median_ts, color="C0", label="Median") +axes.legend() ``` @@ -523,12 +528,14 @@ And now we plot the posterior predictive distributions with a `{python} n_foreca ```{python} # | label: fig-output-posterior-predictive-forecast # | fig-cap: Posterior predictive admissions, including a forecast. +x_data = ( + idata_weekday.posterior_predictive["negbinom_rv_dim_0"] + gen_int.size() +) +y_data = idata_weekday.posterior_predictive["negbinom_rv"] fig, axes = plt.subplots(figsize=(6, 5)) az.plot_hdi( - idata_weekday.posterior_predictive["negbinom_rv_dim_0"] + gen_int.size(), - hdi_data=compute_eti( - idata_weekday.posterior_predictive["negbinom_rv"], 0.9 - ), + x_data, + hdi_data=compute_eti(y_data, 0.9), color="C0", smooth=False, fill_kwargs={"alpha": 0.3}, @@ -536,24 +543,22 @@ az.plot_hdi( ) az.plot_hdi( - idata_weekday.posterior_predictive["negbinom_rv_dim_0"] + gen_int.size(), - hdi_data=compute_eti( - idata_weekday.posterior_predictive["negbinom_rv"], 0.5 - ), + x_data, + hdi_data=compute_eti(y_data, 0.5), color="C0", smooth=False, fill_kwargs={"alpha": 0.6}, ax=axes, ) -# Add mean of the posterior to the figure -mean_latent_infection = y_data.mean(dim=["chain", "draw"]) +# Add median of the posterior to the figure +median_ts = y_data.median(dim=["chain", "draw"]) plt.plot( - idata_weekday.posterior_predictive["negbinom_rv_dim_0"] + gen_int.size(), - mean_latent_infection, + x_data, + median_ts, color="C0", - label="Mean", + label="Median", ) plt.scatter( idata_weekday.observed_data["negbinom_rv_dim_0"] + gen_int.size(), From 0afa30ea4bcf4454a41c3fc6581dcb11ec7f7b07 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 22 Jul 2024 15:18:39 -0400 Subject: [PATCH 30/33] Means to medians in basic_renewal tutorial, avoid issues with data displacement --- docs/source/tutorials/basic_renewal_model.qmd | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index 9d0fe97c..336e9701 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -305,8 +305,8 @@ We inspect the posterior distribution of $\mathcal{R}(t)$ by plotting the 90% an ```{python} # | label: fig-hdi-Rt # | fig-cap: High density interval for Effective Reproduction Number -x_data = idata.posterior["Rt_dim_0"] -y_data = idata.posterior["Rt"] +x_data = idata.posterior["Rt_dim_0"][4:] +y_data = idata.posterior["Rt"][::, ::, 4:] fig, axes = plt.subplots(figsize=(6, 5)) az.plot_hdi( @@ -328,8 +328,8 @@ az.plot_hdi( ) # Add mean of the posterior to the figure -mean_Rt = y_data.mean(dim=["chain", "draw"]) -axes.plot(x_data, mean_Rt, color="C0", label="Mean") +median_ts = y_data.median(dim=["chain", "draw"]) +axes.plot(x_data, median_ts, color="C0", label="Median") axes.legend() axes.set_title("Posterior Effective Reproduction Number", fontsize=10) axes.set_xlabel("Time", fontsize=10) @@ -366,10 +366,10 @@ az.plot_hdi( ax=axes, ) -# Add mean of the posterior to the figure -mean_latent_infections = y_data.mean(dim=["chain", "draw"]) +# plot the posterior median +median_ts = y_data.median(dim=["chain", "draw"]) +axes.plot(x_data, median_ts, color="C0", label="Median") -axes.plot(x_data, mean_latent_infections, color="C0", label="Mean") axes.legend() axes.set_title("Posterior Latent Infections", fontsize=10) axes.set_xlabel("Time", fontsize=10) From f23adb5b8abffa5fa60670b15be89cc35bd3b1ef Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 22 Jul 2024 16:31:14 -0400 Subject: [PATCH 31/33] n_timepoints ==> n_steps as argument to SimpleRandomWalkProcess --- docs/source/tutorials/basic_renewal_model.qmd | 2 +- docs/source/tutorials/hospital_admissions_model.qmd | 2 +- model/src/pyrenew/model/rtinfectionsrenewalmodel.py | 2 +- model/src/pyrenew/process/simplerandomwalk.py | 8 ++++---- model/src/test/test_latent_admissions.py | 2 +- model/src/test/test_latent_infections.py | 2 +- model/src/test/test_random_walk.py | 6 +++--- 7 files changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index 336e9701..7c170af2 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -154,7 +154,7 @@ class MyRt(RandomVariable): ), transforms=t.ExpTransform(), ) - return rt_rv.sample(n_timepoints=n_timepoints, **kwargs) + return rt_rv.sample(n_steps=n_timepoints, **kwargs) rt_proc = MyRt() diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index edb60ef7..15066834 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -204,7 +204,7 @@ class MyRt(metaclass.RandomVariable): transforms=transformation.ExpTransform(), ) - return rt_rv.sample(n_timepoints=n_timepoints, **kwargs) + return rt_rv.sample(n_steps=n_timepoints, **kwargs) rtproc = MyRt() diff --git a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py index 21fca320..8db8bd6e 100644 --- a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py +++ b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py @@ -194,7 +194,7 @@ def sample( # Sampling from Rt (possibly with a given Rt, depending on # the Rt_process (RandomVariable) object.) Rt, *_ = self.Rt_process_rv( - n_timepoints=n_timepoints, + n_steps=n_timepoints, **kwargs, ) diff --git a/model/src/pyrenew/process/simplerandomwalk.py b/model/src/pyrenew/process/simplerandomwalk.py index de66c3fe..cc396192 100644 --- a/model/src/pyrenew/process/simplerandomwalk.py +++ b/model/src/pyrenew/process/simplerandomwalk.py @@ -52,7 +52,7 @@ def __init__( def sample( self, - n_timepoints: int, + n_steps: int, **kwargs, ) -> tuple: """ @@ -60,7 +60,7 @@ def sample( Parameters ---------- - n_timepoints : int + n_steps : int Length of the walk to sample. **kwargs : dict, optional Additional keyword arguments passed through to internal sample() @@ -69,7 +69,7 @@ def sample( Returns ------- tuple - With a single array of shape (n_timepoints,). + With a single array of shape (n_steps,). """ init, *_ = self.init_rv(**kwargs) @@ -83,7 +83,7 @@ def transition(x_prev, _): _, x = scan( transition, init=init, - xs=jnp.arange(n_timepoints - 1), + xs=jnp.arange(n_steps - 1), ) return (jnp.hstack([init, x.flatten()]),) diff --git a/model/src/test/test_latent_admissions.py b/model/src/test/test_latent_admissions.py index a9e2bb5c..9033e9bf 100644 --- a/model/src/test/test_latent_admissions.py +++ b/model/src/test/test_latent_admissions.py @@ -33,7 +33,7 @@ def test_admissions_sample(): ) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - sim_rt, *_ = rt(n_timepoints=30) + sim_rt, *_ = rt(n_steps=30) gen_int = jnp.array([0.5, 0.1, 0.1, 0.2, 0.1]) i0 = 10 * jnp.ones_like(gen_int) diff --git a/model/src/test/test_latent_infections.py b/model/src/test/test_latent_infections.py index 04638d59..f330464f 100755 --- a/model/src/test/test_latent_infections.py +++ b/model/src/test/test_latent_infections.py @@ -31,7 +31,7 @@ def test_infections_as_deterministic(): ) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - sim_rt, *_ = rt(n_timepoints=30) + sim_rt, *_ = rt(n_steps=30) gen_int = jnp.array([0.25, 0.25, 0.25, 0.25]) diff --git a/model/src/test/test_random_walk.py b/model/src/test/test_random_walk.py index d6d6ad09..66be96db 100755 --- a/model/src/test/test_random_walk.py +++ b/model/src/test/test_random_walk.py @@ -30,8 +30,8 @@ def test_rw_can_be_sampled(): with numpyro.handlers.seed(rng_seed=62): # can sample with a fixed init # and with a random init - ans_rand = rw_init_rand(n_timepoints=3532) - ans_fixed = rw_init_fixed(n_timepoints=5023) + ans_rand = rw_init_rand(n_steps=3532) + ans_fixed = rw_init_fixed(n_steps=5023) # check that the samples are of the right shape assert ans_rand[0].shape == (3532,) @@ -63,7 +63,7 @@ def test_rw_samples_correctly_distributed(): ) with numpyro.handlers.seed(rng_seed=62): - samples, *_ = rw_normal(n_timepoints=n_samples) + samples, *_ = rw_normal(n_steps=n_samples) # Checking the shape assert samples.shape == (n_samples,) From 40aefbb4cd9a5d3f368aa2936935779626dad82a Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 22 Jul 2024 16:38:43 -0400 Subject: [PATCH 32/33] update pyrenew_demo.qmd to use n_steps --- docs/source/tutorials/pyrenew_demo.qmd | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/source/tutorials/pyrenew_demo.qmd b/docs/source/tutorials/pyrenew_demo.qmd index edba0e75..598308c8 100644 --- a/docs/source/tutorials/pyrenew_demo.qmd +++ b/docs/source/tutorials/pyrenew_demo.qmd @@ -41,20 +41,19 @@ from pyrenew.process import SimpleRandomWalkProcess from pyrenew.metaclass import DistributionalRV ``` -To understand the simple random walk process underlying the sampling within the renewal process model, we first examine a single random walk path. Using the `sample` method from an instance of the `SimpleRandomWalkProcess` class, we first create an instance of the `SimpleRandomWalkProcess` class with a normal distribution of mean = 0 and standard deviation = 0.0001 as its input. Next, the `with` statement sets the seed for the random number generator for the n_timepoints of the block that follows. Inside the `with` block, the `q_samp = q(n_timepoints=100)` generates the sample instance over a n_timepoints of 100 time units. Finally, this single random walk process is visualized using `matplot.pyplot` to plot the exponential of the sample instance. +To understand the simple random walk process underlying the sampling within the renewal process model, we first examine a single random walk path. Using the `sample` method from an instance of the `SimpleRandomWalkProcess` class, we first create an instance of the `SimpleRandomWalkProcess` class with a normal distribution of mean = 0 and standard deviation = 0.0001 as its input. Next, the `with` statement sets the seed for the random number generator for the block that follows. Inside the `with` block, the `q_samp = q(n_steps=100)` generates the sample instance over a `n_steps` period of 100 time units. Finally, this single random walk process is visualized using `matplot.pyplot` to plot the exponential of the sample instance. ```{python} # | label: fig-randwalk # | fig-cap: Random walk example -np.random.seed(3312) q = SimpleRandomWalkProcess( "example_random_walk", step_rv=DistributionalRV(dist.Normal(0, 0.001), "step_rv"), init_rv=DistributionalRV(dist.Normal(0, 0.001), "init_rv"), ) -with seed(rng_seed=np.random.randint(0, 1000)): - q_samp = q(n_timepoints=100) +with seed(rng_seed=325): + q_samp = q(n_steps=100) plt.plot(np.exp(q_samp[0])) ``` From 71bd14678891e1b84cf2ed9b75529f27e2037d0f Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 22 Jul 2024 16:53:17 -0400 Subject: [PATCH 33/33] Update tutorials --- docs/source/tutorials/basic_renewal_model.qmd | 4 ++-- docs/source/tutorials/hospital_admissions_model.qmd | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index 7c170af2..fa03f957 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -136,7 +136,7 @@ class MyRt(RandomVariable): def validate(self): pass - def sample(self, n_timepoints: int, **kwargs) -> tuple: + def sample(self, n_steps: int, **kwargs) -> tuple: sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025)) rt_rv = TransformedRandomVariable( @@ -154,7 +154,7 @@ class MyRt(RandomVariable): ), transforms=t.ExpTransform(), ) - return rt_rv.sample(n_steps=n_timepoints, **kwargs) + return rt_rv.sample(n_steps=n_steps, **kwargs) rt_proc = MyRt() diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index 15066834..7e89df33 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -187,7 +187,7 @@ class MyRt(metaclass.RandomVariable): def validate(self): pass - def sample(self, n_timepoints: int, **kwargs) -> tuple: + def sample(self, n_steps: int, **kwargs) -> tuple: sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025)) rt_rv = metaclass.TransformedRandomVariable( @@ -204,7 +204,7 @@ class MyRt(metaclass.RandomVariable): transforms=transformation.ExpTransform(), ) - return rt_rv.sample(n_steps=n_timepoints, **kwargs) + return rt_rv.sample(n_steps=n_steps, **kwargs) rtproc = MyRt()