diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index 2545aee1..eec79ec2 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -224,11 +224,11 @@ import matplotlib.pyplot as plt fig, axs = plt.subplots(1, 2) # Rt plot -axs[0].plot(sim_data.Rt) +axs[0].plot(sim_data.Rt.value) axs[0].set_ylabel("Rt") # Infections plot -axs[1].plot(sim_data.observed_infections) +axs[1].plot(sim_data.observed_infections.value) axs[1].set_ylabel("Infections") fig.suptitle("Basic renewal model") @@ -246,7 +246,7 @@ import jax model1.run( num_warmup=2000, num_samples=1000, - data_observed_infections=sim_data.observed_infections, + data_observed_infections=sim_data.observed_infections.value, rng_key=jax.random.PRNGKey(54), mcmc_args=dict(progress_bar=False, num_chains=2), ) diff --git a/docs/source/tutorials/extending_pyrenew.qmd b/docs/source/tutorials/extending_pyrenew.qmd index 48e1f80e..664cd3bd 100644 --- a/docs/source/tutorials/extending_pyrenew.qmd +++ b/docs/source/tutorials/extending_pyrenew.qmd @@ -42,15 +42,17 @@ The following code-chunk defines the model components. Notice that for both the ```{python} # | label: model-components gen_int_array = jnp.array([0.25, 0.5, 0.15, 0.1]) + gen_int = DeterministicPMF(name="gen_int", value=gen_int_array) -feedback_strength = DeterministicVariable(name="feedback_strength", value=0.05) +feedback_strength = DeterministicVariable(name="feedback_strength", value=0.01) + I0 = InfectionInitializationProcess( "I0_initialization", DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), InitializeInfectionsExponentialGrowth( gen_int_array.size, - DeterministicVariable(name="rate", value=0.5), + DeterministicVariable(name="rate", value=0.05), ), t_unit=1, ) @@ -103,7 +105,7 @@ with numpyro.handlers.seed(rng_seed=223): import matplotlib.pyplot as plt fig, ax = plt.subplots() -ax.plot(model0_samp.latent_infections) +ax.plot(model0_samp.latent_infections.value) ax.set_xlabel("Time") ax.set_ylabel("Infections") plt.show() @@ -160,7 +162,7 @@ The next step is to create the actual class. The bulk of its implementation lies # | label: new-model-def # | code-line-numbers: true # Creating the class -from pyrenew.metaclass import RandomVariable +from pyrenew.metaclass import RandomVariable, SampledValue from pyrenew.latent import compute_infections_from_rt_with_feedback from pyrenew import arrayutils as au from jax.typing import ArrayLike @@ -208,12 +210,14 @@ class InfFeedback(RandomVariable): **kwargs, ) inf_feedback_strength = au.pad_x_to_match_y( - x=inf_feedback_strength, y=Rt, fill_value=inf_feedback_strength[0] + x=inf_feedback_strength.value, + y=Rt, + fill_value=inf_feedback_strength.value[0], ) # Sampling inf feedback and adjusting the shape inf_feedback_pmf, *_ = self.infection_feedback_pmf(**kwargs) - inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf) + inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf.value) # Generating the infections with feedback all_infections, Rt_adj = compute_infections_from_rt_with_feedback( @@ -230,8 +234,8 @@ class InfFeedback(RandomVariable): # Preparing theoutput return InfFeedbackSample( - infections=all_infections, - rt=Rt_adj, + infections=SampledValue(all_infections), + rt=SampledValue(Rt_adj), ) ``` @@ -273,8 +277,8 @@ Comparing `model0` with `model1`, these two should match: import matplotlib.pyplot as plt fig, ax = plt.subplots(ncols=2) -ax[0].plot(model0_samp.latent_infections) -ax[1].plot(model1_samp.latent_infections) +ax[0].plot(model0_samp.latent_infections.value) +ax[1].plot(model1_samp.latent_infections.value) ax[0].set_xlabel("Time (model 0)") ax[1].set_xlabel("Time (model 1)") ax[0].set_ylabel("Infections") diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index 478f00f2..32bab32d 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -199,7 +199,7 @@ class MyRt(metaclass.RandomVariable): base_rv=process.SimpleRandomWalkProcess( name="log_rt", step_rv=metaclass.DistributionalRV( - name="rw_step_rv", dist=dist.Normal(0, sd_rt) + name="rw_step_rv", dist=dist.Normal(0, sd_rt.value) ), init_rv=metaclass.DistributionalRV( name="init_log_Rt_rv", dist=dist.Normal(0, 0.2) @@ -272,11 +272,11 @@ import matplotlib.pyplot as plt fig, axs = plt.subplots(1, 2) # Rt plot -axs[0].plot(simulated_data.Rt) +axs[0].plot(simulated_data.Rt.value) axs[0].set_ylabel("Simulated Rt") # Admissions plot -axs[1].plot(simulated_data.observed_hosp_admissions, "-o") +axs[1].plot(simulated_data.observed_hosp_admissions.value, "-o") axs[1].set_ylabel("Simulated Admissions") fig.suptitle("Basic renewal model") diff --git a/docs/source/tutorials/periodic_effects.qmd b/docs/source/tutorials/periodic_effects.qmd index 84702ac2..2cd1db8a 100644 --- a/docs/source/tutorials/periodic_effects.qmd +++ b/docs/source/tutorials/periodic_effects.qmd @@ -46,7 +46,7 @@ with numpyro.handlers.seed(rng_seed=20): # Plotting the Rt values import matplotlib.pyplot as plt -plt.step(np.arange(len(sim_data.rt)), sim_data.rt, where="post") +plt.step(np.arange(len(sim_data.rt.value)), sim_data.rt.value, where="post") plt.xlabel("Time") plt.ylabel("Rt") plt.title("Simulated Rt values") @@ -92,7 +92,9 @@ with numpyro.handlers.seed(rng_seed=20): # Plotting the effect values import matplotlib.pyplot as plt -plt.step(np.arange(len(sim_data.value)), sim_data.value, where="post") +plt.step( + np.arange(len(sim_data.value.value)), sim_data.value.value, where="post" +) plt.xlabel("Time") plt.ylabel("Effect size") plt.title("Simulated Day of Week Effect values") diff --git a/docs/source/tutorials/time.qmd b/docs/source/tutorials/time.qmd index 13834a59..9a2263fb 100644 --- a/docs/source/tutorials/time.qmd +++ b/docs/source/tutorials/time.qmd @@ -10,9 +10,16 @@ The fundamental time unit should represent a period of fixed (or approximately f For many infectious disease renewal models of interest, the fundamental time unit will be days, and we will proceed with this tutorial treating days as our fundamental unit. - `pyrenew` deals with time having `RandomVariable`s carry information about (i) their own time unit expressed relative to the fundamental unit (`t_unit`) and (ii) the starting time, `t_start`, measured relative to `t = 0` in model time in fundamental time units. +`pyrenew` deals with time by having `RandomVariable`s carry information about -The tuple `(t_unit, t_start)` can encode different types of time series data. For example: +1. their own time unit expressed relative to the fundamental unit (`t_unit`) and +2. the starting time, `t_start`, measured relative to `t = 0` in model time in fundamental time units. + +Return values from `RandomVariable.sample()` are `tuples` or `namedtuple`s of `SampledValue` objects. `SampledValue` objects can have `t_start` and `t_unit` attributes. + +By default, `SampledValue` objects carry the `t_start` and `t_unit` of the `RandomVariable` from which they are `sample()`-d. One might override this default to allow a `RandomVariable.sample()` call to produce multiple `SampledValue`s with different time-units, or with different start-points relative to the `RandomVariable`'s own `t_start`. + +The `t_unit, t_start` pair can encode different types of time series data. For example: | Description | `t_unit` | `t_start` | |:-----------------|----------------:|-----------------:| @@ -31,10 +38,6 @@ The `PeriodicBroadcaster()` class provides a way of tiling and repeating data ac The following section describes some preliminary design principles that may be included in future versions of `pyrenew`. -### Validation - -With random variables possibly spanning different time scales, *e.g.*, weekly, daily, hourly, the metaclass `Model` should ensure random variables within the model share the same time unit. - ### Array alignment Using `t_unit` and `t_start`, random variables should be able to align input and output data. For example, in the case of the `RtInfectionsRenewalModel()`, the computed values of `Rt` and `infections` are padded left with `nan` values to account for the initialization process. Instead, we expect to either pre-process the padding leveraging the `t_start` information of the involved variables or simplify the process via a function call that aligns the arrays. A possible implementation could be a method `align()` that takes a list of random variables and aligns them based on the `t_unit` and `t_start` information, e.g.: @@ -42,3 +45,7 @@ Using `t_unit` and `t_start`, random variables should be able to align input and ```python Rt_aligned, infections_aligned = align([Rt, infections]) ``` + +### Retrieving time information from sites + +Future versions of `pyrenew` could include a way to retrieve the time information for sites keyed by site name the model. diff --git a/model/src/pyrenew/deterministic/deterministic.py b/model/src/pyrenew/deterministic/deterministic.py index c9ff9d8f..2bb03333 100644 --- a/model/src/pyrenew/deterministic/deterministic.py +++ b/model/src/pyrenew/deterministic/deterministic.py @@ -6,7 +6,7 @@ import jax.numpy as jnp import numpyro from jax.typing import ArrayLike -from pyrenew.metaclass import RandomVariable +from pyrenew.metaclass import RandomVariable, SampledValue class DeterministicVariable(RandomVariable): @@ -19,24 +19,30 @@ def __init__( self, name: str, value: ArrayLike, + t_start: int | None = None, + t_unit: int | None = None, ) -> None: """Default constructor Parameters ---------- name : str - A name to assign to the process. + A name to assign to the variable. value : ArrayLike An ArrayLike object. + t_start : int, optional + The start time of the variable, if any. + t_unit : int, optional + The unit of time relative to the model's fundamental (smallest) time unit, if any Returns ------- None """ - self.name = name self.value = jnp.atleast_1d(value) self.validate(value) + self.set_timeseries(t_start, t_unit) return None @@ -75,16 +81,27 @@ def sample( Parameters ---------- record : bool, optional - Whether to record the value of the deterministic RandomVariable. Defaults to True. + Whether to record the value of the deterministic + RandomVariable. Defaults to True. **kwargs : dict, optional Additional keyword arguments passed through to internal sample calls, should there be any. Returns ------- - tuple - Containing the stored values during construction. + tuple[SampledValue] + A length-one tuple whose single entry is a + :class:`SampledValue` + instance with `value=self.value`, + `t_start=self.t_start`, and + `t_unit=self.t_unit`. """ if record: numpyro.deterministic(self.name, self.value) - return (self.value,) + return ( + SampledValue( + value=self.value, + t_start=self.t_start, + t_unit=self.t_unit, + ), + ) diff --git a/model/src/pyrenew/deterministic/deterministicpmf.py b/model/src/pyrenew/deterministic/deterministicpmf.py index 04fdb009..3e4611c2 100644 --- a/model/src/pyrenew/deterministic/deterministicpmf.py +++ b/model/src/pyrenew/deterministic/deterministicpmf.py @@ -18,6 +18,8 @@ def __init__( name: str, value: ArrayLike, tol: float = 1e-5, + t_start: int | None = None, + t_unit: int | None = None, ) -> None: """ Default constructor @@ -36,6 +38,11 @@ def __init__( tol : float, optional Passed to pyrenew.distutil.validate_discrete_dist_vector. Defaults to 1e-5. + t_start : int, optional + The start time of the process. + t_unit : int, optional + The unit of time relative to the model's fundamental (smallest) + time unit. Returns ------- @@ -46,7 +53,12 @@ def __init__( tol=tol, ) - self.basevar = DeterministicVariable(name=name, value=value) + self.basevar = DeterministicVariable( + name=name, + value=value, + t_start=t_start, + t_unit=t_unit, + ) return None @@ -82,7 +94,7 @@ def sample( Returns ------- tuple - Containing the stored values during construction. + Containing the stored values during construction wrapped in a SampledValue. """ return self.basevar.sample(**kwargs) diff --git a/model/src/pyrenew/deterministic/nullrv.py b/model/src/pyrenew/deterministic/nullrv.py index f25e1929..4fc851c8 100644 --- a/model/src/pyrenew/deterministic/nullrv.py +++ b/model/src/pyrenew/deterministic/nullrv.py @@ -4,6 +4,7 @@ from jax.typing import ArrayLike from pyrenew.deterministic.deterministic import DeterministicVariable +from pyrenew.metaclass import SampledValue class NullVariable(DeterministicVariable): @@ -46,10 +47,10 @@ def sample( Returns ------- tuple - Containing None. + Containing a SampledValue with None. """ - return (None,) + return (SampledValue(None, t_start=self.t_start, t_unit=self.t_unit),) class NullProcess(NullVariable): @@ -95,10 +96,10 @@ def sample( Returns ------- tuple - Containing None. + Containing a SampledValue with None. """ - return (None,) + return (SampledValue(None, t_start=self.t_start, t_unit=self.t_unit),) class NullObservation(NullVariable): @@ -148,7 +149,7 @@ def sample( Returns ------- tuple - Containing None. + Containing a SampledValue with None. """ - return (None,) + return (SampledValue(None, t_start=self.t_start, t_unit=self.t_unit),) diff --git a/model/src/pyrenew/deterministic/process.py b/model/src/pyrenew/deterministic/process.py index 64f5a514..1f9bff53 100644 --- a/model/src/pyrenew/deterministic/process.py +++ b/model/src/pyrenew/deterministic/process.py @@ -2,6 +2,7 @@ import jax.numpy as jnp from pyrenew.deterministic.deterministic import DeterministicVariable +from pyrenew.metaclass import SampledValue class DeterministicProcess(DeterministicVariable): @@ -28,15 +29,24 @@ def sample( Returns ------- - tuple - Containing the stored values during construction. + tuple[SampledValue] + containing the deterministic value(s) provided + at construction as a series of length `duration`. """ res, *_ = super().sample(**kwargs) - dif = duration - res.shape[0] + dif = duration - res.value.shape[0] if dif > 0: - return (jnp.hstack([res, jnp.repeat(res[-1], dif)]),) + value = jnp.hstack([res.value, jnp.repeat(res.value[-1], dif)]) + else: + value = res.value[:duration] - return (res[:duration],) + res = SampledValue( + value, + t_start=self.t_start, + t_unit=self.t_unit, + ) + + return (res,) diff --git a/model/src/pyrenew/latent/hospitaladmissions.py b/model/src/pyrenew/latent/hospitaladmissions.py index a7f5dbb1..a6ad5cb1 100644 --- a/model/src/pyrenew/latent/hospitaladmissions.py +++ b/model/src/pyrenew/latent/hospitaladmissions.py @@ -9,7 +9,7 @@ import numpyro from jax.typing import ArrayLike from pyrenew.deterministic import DeterministicVariable -from pyrenew.metaclass import RandomVariable +from pyrenew.metaclass import RandomVariable, SampledValue class HospitalAdmissionsSample(NamedTuple): @@ -18,14 +18,14 @@ class HospitalAdmissionsSample(NamedTuple): Attributes ---------- - infection_hosp_rate : float, optional + infection_hosp_rate : SampledValue, optional The infection-to-hospitalization rate. Defaults to None. - latent_hospital_admissions : ArrayLike or None + latent_hospital_admissions : SampledValue or None The computed number of hospital admissions. Defaults to None. """ - infection_hosp_rate: float | None = None - latent_hospital_admissions: ArrayLike | None = None + infection_hosp_rate: SampledValue | None = None + latent_hospital_admissions: SampledValue | None = None def __repr__(self): return f"HospitalAdmissionsSample(infection_hosp_rate={self.infection_hosp_rate}, latent_hospital_admissions={self.latent_hospital_admissions})" @@ -158,7 +158,7 @@ def sample( Parameters ---------- - latent : ArrayLike + latent_infections : ArrayLike Latent infections. **kwargs : dict, optional Additional keyword arguments passed through to internal `sample()` @@ -171,7 +171,7 @@ def sample( infection_hosp_rate, *_ = self.infect_hosp_rate_rv(**kwargs) - infection_hosp_rate_t = infection_hosp_rate * latent_infections + infection_hosp_rate_t = infection_hosp_rate.value * latent_infections ( infection_to_admission_interval, @@ -180,19 +180,22 @@ def sample( latent_hospital_admissions = jnp.convolve( infection_hosp_rate_t, - infection_to_admission_interval, + infection_to_admission_interval.value, mode="full", )[: infection_hosp_rate_t.shape[0]] # Applying the day of the week effect latent_hospital_admissions = ( latent_hospital_admissions - * self.day_of_week_effect_rv(**kwargs)[0] + * self.day_of_week_effect_rv( + n_timepoints=latent_hospital_admissions.size, **kwargs + )[0].value ) - # Applying probability of hospitalization effect + # Applying reporting probability latent_hospital_admissions = ( - latent_hospital_admissions * self.hosp_report_prob_rv(**kwargs)[0] + latent_hospital_admissions + * self.hosp_report_prob_rv(**kwargs)[0].value ) numpyro.deterministic( @@ -200,5 +203,10 @@ def sample( ) return HospitalAdmissionsSample( - infection_hosp_rate, latent_hospital_admissions + infection_hosp_rate=infection_hosp_rate, + latent_hospital_admissions=SampledValue( + value=latent_hospital_admissions, + t_start=self.t_start, + t_unit=self.t_unit, + ), ) diff --git a/model/src/pyrenew/latent/infection_initialization_method.py b/model/src/pyrenew/latent/infection_initialization_method.py index 62de52c0..3f58d93e 100644 --- a/model/src/pyrenew/latent/infection_initialization_method.py +++ b/model/src/pyrenew/latent/infection_initialization_method.py @@ -176,7 +176,7 @@ def initialize_infections(self, I_pre_init: ArrayLike): raise ValueError( f"I_pre_init must be an array of size 1. Got size {I_pre_init.size}." ) - (rate,) = self.rate() + rate = self.rate()[0].value if rate.size != 1: raise ValueError( f"rate must be an array of size 1. Got size {rate.size}." diff --git a/model/src/pyrenew/latent/infection_initialization_process.py b/model/src/pyrenew/latent/infection_initialization_process.py index a92e6ade..8e8c62e2 100644 --- a/model/src/pyrenew/latent/infection_initialization_process.py +++ b/model/src/pyrenew/latent/infection_initialization_process.py @@ -4,7 +4,7 @@ from pyrenew.latent.infection_initialization_method import ( InfectionInitializationMethod, ) -from pyrenew.metaclass import RandomVariable +from pyrenew.metaclass import RandomVariable, SampledValue class InfectionInitializationProcess(RandomVariable): @@ -93,10 +93,21 @@ def sample(self) -> tuple: Returns ------- tuple - a tuple where the only element is an array with the number of initialized infections at each time point. + a tuple where the only element is an array with + the number of initialized infections at each time point. """ + (I_pre_init,) = self.I_pre_init_rv() - infection_initialization = self.infection_init_method(I_pre_init) + + infection_initialization = self.infection_init_method( + I_pre_init.value, + ) numpyro.deterministic(self.name, infection_initialization) - return (infection_initialization,) + return ( + SampledValue( + infection_initialization, + t_start=self.t_start, + t_unit=self.t_unit, + ), + ) diff --git a/model/src/pyrenew/latent/infections.py b/model/src/pyrenew/latent/infections.py index 4780202d..e5da11d6 100644 --- a/model/src/pyrenew/latent/infections.py +++ b/model/src/pyrenew/latent/infections.py @@ -8,7 +8,7 @@ import jax.numpy as jnp import pyrenew.latent.infection_functions as inf from jax.typing import ArrayLike -from pyrenew.metaclass import RandomVariable +from pyrenew.metaclass import RandomVariable, SampledValue class InfectionsSample(NamedTuple): @@ -17,7 +17,7 @@ class InfectionsSample(NamedTuple): Attributes ---------- - post_initialization_infections : ArrayLike | None, optional + post_initialization_infections : SampledValue | None, optional The estimated latent infections. Defaults to None. """ @@ -97,4 +97,10 @@ def sample( reversed_generation_interval_pmf=gen_int_rev, ) - return InfectionsSample(post_initialization_infections) + return InfectionsSample( + SampledValue( + post_initialization_infections, + t_start=self.t_start, + t_unit=self.t_unit, + ) + ) diff --git a/model/src/pyrenew/latent/infectionswithfeedback.py b/model/src/pyrenew/latent/infectionswithfeedback.py index 041a1395..fffa7307 100644 --- a/model/src/pyrenew/latent/infectionswithfeedback.py +++ b/model/src/pyrenew/latent/infectionswithfeedback.py @@ -8,7 +8,11 @@ import pyrenew.arrayutils as au import pyrenew.latent.infection_functions as inf from numpy.typing import ArrayLike -from pyrenew.metaclass import RandomVariable, _assert_sample_and_rtype +from pyrenew.metaclass import ( + RandomVariable, + SampledValue, + _assert_sample_and_rtype, +) class InfectionsRtFeedbackSample(NamedTuple): @@ -17,14 +21,14 @@ class InfectionsRtFeedbackSample(NamedTuple): Attributes ---------- - post_initialization_infections : ArrayLike | None, optional + post_initialization_infections : SampledValue | None, optional The estimated latent infections. Defaults to None. - rt : ArrayLike | None, optional + rt : SampledValue | None, optional The adjusted reproduction number. Defaults to None. """ - post_initialization_infections: ArrayLike | None = None - rt: ArrayLike | None = None + post_initialization_infections: SampledValue | None = None + rt: SampledValue | None = None def __repr__(self): return f"InfectionsSample(post_initialization_infections={self.post_initialization_infections}, rt={self.rt})" @@ -156,9 +160,9 @@ def sample( I0 = I0[-gen_int_rev.size :] # Sampling inf feedback strength - inf_feedback_strength, *_ = self.infection_feedback_strength( + inf_feedback_strength = self.infection_feedback_strength( **kwargs, - ) + )[0].value # Making sure inf_feedback_strength spans the Rt length if inf_feedback_strength.size == 1: @@ -177,7 +181,7 @@ def sample( # Sampling inf feedback pmf inf_feedback_pmf, *_ = self.infection_feedback_pmf(**kwargs) - inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf) + inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf.value) ( post_initialization_infections, @@ -195,6 +199,10 @@ def sample( numpyro.deterministic("Rt_adjusted", Rt_adj) return InfectionsRtFeedbackSample( - post_initialization_infections=post_initialization_infections, - rt=Rt_adj, + post_initialization_infections=SampledValue( + value=post_initialization_infections, + t_start=self.t_start, + t_unit=self.t_unit, + ), + rt=SampledValue(Rt_adj, t_start=self.t_start, t_unit=self.t_unit), ) diff --git a/model/src/pyrenew/metaclass.py b/model/src/pyrenew/metaclass.py index fd058d24..06df3be5 100644 --- a/model/src/pyrenew/metaclass.py +++ b/model/src/pyrenew/metaclass.py @@ -5,7 +5,7 @@ """ from abc import ABCMeta, abstractmethod -from typing import get_type_hints +from typing import NamedTuple, get_type_hints import jax import jax.numpy as jnp @@ -94,6 +94,28 @@ def _assert_sample_and_rtype( return None +class SampledValue(NamedTuple): + """ + A container for a sampled value from a RandomVariable. + + Attributes + ---------- + value : ArrayLike, optional + The sampled value. + t_start : int, optional + The start time of the value. + t_unit : int, optional + The unit of time relative to the model's fundamental (smallest) time unit. + """ + + value: ArrayLike | None = None + t_start: int | None = None + t_unit: int | None = None + + def __repr__(self): + return f"SampledValue(value={self.value}, t_start={self.t_start}, t_unit={self.t_unit})" + + class RandomVariable(metaclass=ABCMeta): """ Abstract base class for latent and observed random variables. @@ -153,6 +175,18 @@ def set_timeseries( ------- None """ + + # Either both values are None or both are not None + assert (t_unit is not None and t_start is not None) or ( + t_unit is None and t_start is None + ), ( + "Both t_start and t_unit should be None or not None. " + "Currently, t_start is {t_start} and t_unit is {t_unit}." + ) + + if t_unit is None and t_start is None: + return None + # Timeseries unit should be a positive integer assert isinstance( t_unit, int @@ -292,7 +326,13 @@ def sample( fn=self.dist, obs=obs, ) - return (jnp.atleast_1d(sample),) + return ( + SampledValue( + jnp.atleast_1d(sample), + t_start=self.t_start, + t_unit=self.t_unit, + ), + ) class Model(metaclass=ABCMeta): @@ -643,7 +683,12 @@ def sample(self, **kwargs) -> tuple: untransformed_values = self.base_rv.sample(**kwargs) return tuple( - t(uv) for t, uv in zip(self.transforms, untransformed_values) + SampledValue( + t(uv.value), + t_start=self.t_start, + t_unit=self.t_unit, + ) + for t, uv in zip(self.transforms, untransformed_values) ) def sample_length(self): diff --git a/model/src/pyrenew/model/admissionsmodel.py b/model/src/pyrenew/model/admissionsmodel.py index f87090af..3bf3aa50 100644 --- a/model/src/pyrenew/model/admissionsmodel.py +++ b/model/src/pyrenew/model/admissionsmodel.py @@ -7,7 +7,12 @@ from jax.typing import ArrayLike from pyrenew.deterministic import NullObservation -from pyrenew.metaclass import Model, RandomVariable, _assert_sample_and_rtype +from pyrenew.metaclass import ( + Model, + RandomVariable, + SampledValue, + _assert_sample_and_rtype, +) from pyrenew.model.rtinfectionsrenewalmodel import RtInfectionsRenewalModel @@ -17,23 +22,23 @@ class HospModelSample(NamedTuple): Attributes ---------- - Rt : float | None, optional + Rt : SampledValue | None, optional The reproduction number over time. Defaults to None. - latent_infections : ArrayLike | None, optional + latent_infections : SampledValue | None, optional The estimated number of new infections over time. Defaults to None. - infection_hosp_rate : float | None, optional + infection_hosp_rate : SampledValue | None, optional The infected hospitalization rate. Defaults to None. - latent_hosp_admissions : ArrayLike | None, optional + latent_hosp_admissions : SampledValue | None, optional The estimated latent hospitalizations. Defaults to None. - observed_hosp_admissions : ArrayLike | None, optional + observed_hosp_admissions : SampledValue | None, optional The sampled or observed hospital admissions. Defaults to None. """ - Rt: float | None = None - latent_infections: ArrayLike | None = None - infection_hosp_rate: float | None = None - latent_hosp_admissions: ArrayLike | None = None - observed_hosp_admissions: ArrayLike | None = None + Rt: SampledValue | None = None + latent_infections: SampledValue | None = None + infection_hosp_rate: SampledValue | None = None + latent_hosp_admissions: SampledValue | None = None + observed_hosp_admissions: SampledValue | None = None def __repr__(self): return ( @@ -195,7 +200,7 @@ def sample( latent_hosp_admissions, *_, ) = self.latent_hosp_admissions_rv( - latent_infections=basic_model.latent_infections, + latent_infections=basic_model.latent_infections.value, **kwargs, ) @@ -203,7 +208,7 @@ def sample( observed_hosp_admissions, *_, ) = self.hosp_admission_obs_process_rv( - mu=latent_hosp_admissions[-n_datapoints:], + mu=latent_hosp_admissions.value[-n_datapoints:], obs=data_observed_hosp_admissions, **kwargs, ) diff --git a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py index 7e04b8d6..e8b46ae5 100644 --- a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py +++ b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py @@ -10,7 +10,12 @@ import pyrenew.arrayutils as au from numpy.typing import ArrayLike from pyrenew.deterministic import NullObservation -from pyrenew.metaclass import Model, RandomVariable, _assert_sample_and_rtype +from pyrenew.metaclass import ( + Model, + RandomVariable, + SampledValue, + _assert_sample_and_rtype, +) # Output class of the RtInfectionsRenewalModel @@ -20,17 +25,17 @@ class RtInfectionsRenewalSample(NamedTuple): Attributes ---------- - Rt : ArrayLike | None, optional + Rt : SampledValue | None, optional The reproduction number over time. Defaults to None. - latent_infections : ArrayLike | None, optional + latent_infections : SampledValue | None, optional The estimated latent infections. Defaults to None. - observed_infections : ArrayLike | None, optional + observed_infections : SampledValue | None, optional The sampled infections. Defaults to None. """ - Rt: ArrayLike | None = None - latent_infections: ArrayLike | None = None - observed_infections: ArrayLike | None = None + Rt: SampledValue | None = None + latent_infections: SampledValue | None = None + observed_infections: SampledValue | None = None def __repr__(self): return ( @@ -202,33 +207,33 @@ def sample( post_initialization_latent_infections, *_, ) = self.latent_infections_rv( - Rt=Rt, - gen_int=gen_int, - I0=I0, + Rt=Rt.value, + gen_int=gen_int.value, + I0=I0.value, **kwargs, ) observed_infections, *_ = self.infection_obs_process_rv( - mu=post_initialization_latent_infections[padding:], + mu=post_initialization_latent_infections.value[padding:], obs=data_observed_infections, **kwargs, ) all_latent_infections = jnp.hstack( - [I0, post_initialization_latent_infections] + [I0.value, post_initialization_latent_infections.value] ) numpyro.deterministic("all_latent_infections", all_latent_infections) if observed_infections is not None: observed_infections = au.pad_x_to_match_y( - observed_infections, + observed_infections.value, all_latent_infections, jnp.nan, pad_direction="start", ) Rt = au.pad_x_to_match_y( - Rt, + Rt.value, all_latent_infections, jnp.nan, pad_direction="start", @@ -236,7 +241,7 @@ def sample( numpyro.deterministic("Rt", Rt) return RtInfectionsRenewalSample( - Rt=Rt, - latent_infections=all_latent_infections, - observed_infections=observed_infections, + Rt=SampledValue(Rt), + latent_infections=SampledValue(all_latent_infections), + observed_infections=SampledValue(observed_infections), ) diff --git a/model/src/pyrenew/observation/negativebinomial.py b/model/src/pyrenew/observation/negativebinomial.py index 1673eba6..cf583021 100644 --- a/model/src/pyrenew/observation/negativebinomial.py +++ b/model/src/pyrenew/observation/negativebinomial.py @@ -6,7 +6,7 @@ import numpyro import numpyro.distributions as dist from jax.typing import ArrayLike -from pyrenew.metaclass import RandomVariable +from pyrenew.metaclass import RandomVariable, SampledValue class NegativeBinomialObservation(RandomVariable): @@ -93,8 +93,15 @@ def sample( name=self.name, fn=dist.NegativeBinomial2( mean=mu + self.eps, - concentration=concentration, + concentration=concentration.value, ), obs=obs, ) - return (negative_binomial_sample,) + + return ( + SampledValue( + negative_binomial_sample, + t_start=self.t_start, + t_unit=self.t_unit, + ), + ) diff --git a/model/src/pyrenew/observation/poisson.py b/model/src/pyrenew/observation/poisson.py index cef4256f..6744efb7 100644 --- a/model/src/pyrenew/observation/poisson.py +++ b/model/src/pyrenew/observation/poisson.py @@ -6,7 +6,7 @@ import numpyro import numpyro.distributions as dist from jax.typing import ArrayLike -from pyrenew.metaclass import RandomVariable +from pyrenew.metaclass import RandomVariable, SampledValue class PoissonObservation(RandomVariable): @@ -72,4 +72,10 @@ def sample( fn=dist.Poisson(rate=mu + self.eps), obs=obs, ) - return (poisson_sample,) + return ( + SampledValue( + poisson_sample, + t_start=self.t_start, + t_unit=self.t_unit, + ), + ) diff --git a/model/src/pyrenew/process/ar.py b/model/src/pyrenew/process/ar.py index bc68c049..8c2203ec 100644 --- a/model/src/pyrenew/process/ar.py +++ b/model/src/pyrenew/process/ar.py @@ -8,7 +8,7 @@ import numpyro.distributions as dist from jax import lax from jax.typing import ArrayLike -from pyrenew.metaclass import RandomVariable +from pyrenew.metaclass import RandomVariable, SampledValue class ARProcess(RandomVariable): @@ -91,7 +91,13 @@ def _ar_scanner(carry, next): # numpydoc ignore=GL08 ) last, ts = lax.scan(_ar_scanner, inits - self.mean, noise) - return (jnp.hstack([inits, self.mean + ts.flatten()]),) + return ( + SampledValue( + jnp.hstack([inits, self.mean + ts.flatten()]), + t_start=self.t_start, + t_unit=self.t_unit, + ), + ) @staticmethod def validate(): # numpydoc ignore=RT01 diff --git a/model/src/pyrenew/process/firstdifferencear.py b/model/src/pyrenew/process/firstdifferencear.py index e3e594cd..22cecc4e 100644 --- a/model/src/pyrenew/process/firstdifferencear.py +++ b/model/src/pyrenew/process/firstdifferencear.py @@ -5,7 +5,7 @@ import jax.numpy as jnp from jax.typing import ArrayLike -from pyrenew.metaclass import RandomVariable +from pyrenew.metaclass import RandomVariable, SampledValue from pyrenew.process import ARProcess @@ -75,7 +75,13 @@ def sample( duration=duration, inits=jnp.atleast_1d(init_rate_of_change), ) - return (init_val + jnp.cumsum(rates_of_change.flatten()),) + return ( + SampledValue( + init_val + jnp.cumsum(rates_of_change.value.flatten()), + t_start=self.t_start, + t_unit=self.t_unit, + ), + ) @staticmethod def validate(): diff --git a/model/src/pyrenew/process/periodiceffect.py b/model/src/pyrenew/process/periodiceffect.py index 61ca1f28..ffb2f183 100644 --- a/model/src/pyrenew/process/periodiceffect.py +++ b/model/src/pyrenew/process/periodiceffect.py @@ -2,9 +2,12 @@ from typing import NamedTuple -import jax.numpy as jnp import pyrenew.arrayutils as au -from pyrenew.metaclass import RandomVariable, _assert_sample_and_rtype +from pyrenew.metaclass import ( + RandomVariable, + SampledValue, + _assert_sample_and_rtype, +) class PeriodicEffectSample(NamedTuple): @@ -14,11 +17,11 @@ class PeriodicEffectSample(NamedTuple): Attributes ---------- - value: jnp.ndarray + value: SampledValue The sampled value. """ - value: jnp.ndarray + value: SampledValue def __repr__(self): return f"PeriodicEffectSample(value={self.value})" @@ -110,9 +113,13 @@ def sample(self, duration: int, **kwargs): """ return PeriodicEffectSample( - value=self.broadcaster( - data=self.quantity_to_broadcast.sample(**kwargs)[0], - n_timepoints=duration, + value=SampledValue( + self.broadcaster( + data=self.quantity_to_broadcast.sample(**kwargs)[0].value, + n_timepoints=duration, + ), + t_start=self.t_start, + t_unit=self.t_unit, ) ) diff --git a/model/src/pyrenew/process/rtperiodicdiff.py b/model/src/pyrenew/process/rtperiodicdiff.py index 3491bd32..1fd5da86 100644 --- a/model/src/pyrenew/process/rtperiodicdiff.py +++ b/model/src/pyrenew/process/rtperiodicdiff.py @@ -4,7 +4,11 @@ import jax.numpy as jnp from jax.typing import ArrayLike from pyrenew.arrayutils import PeriodicBroadcaster -from pyrenew.metaclass import RandomVariable, _assert_sample_and_rtype +from pyrenew.metaclass import ( + RandomVariable, + SampledValue, + _assert_sample_and_rtype, +) from pyrenew.process.firstdifferencear import FirstDifferenceARProcess @@ -14,11 +18,11 @@ class RtPeriodicDiffProcessSample(NamedTuple): Attributes ---------- - rt : ArrayLike + rt : SampledValue, optional The sampled Rt. """ - rt: ArrayLike | None = None + rt: SampledValue | None = None def __repr__(self): return f"RtPeriodicDiffProcessSample(rt={self.rt})" @@ -171,9 +175,9 @@ def sample( """ # Initial sample - log_rt_prior = self.log_rt_prior.sample(**kwargs)[0] - b = self.autoreg.sample(**kwargs)[0] - s_r = self.periodic_diff_sd.sample(**kwargs)[0] + log_rt_prior = self.log_rt_prior.sample(**kwargs)[0].value + b = self.autoreg.sample(**kwargs)[0].value + s_r = self.periodic_diff_sd.sample(**kwargs)[0].value # How many periods to sample? n_periods = int(jnp.ceil(duration / self.period_size)) @@ -187,7 +191,11 @@ def sample( )[0] return RtPeriodicDiffProcessSample( - rt=self.broadcaster(jnp.exp(log_rt.flatten()), duration), + rt=SampledValue( + self.broadcaster(jnp.exp(log_rt.value.flatten()), duration), + t_start=self.t_start, + t_unit=self.t_unit, + ), ) diff --git a/model/src/pyrenew/process/simplerandomwalk.py b/model/src/pyrenew/process/simplerandomwalk.py index cc396192..a88ea0d7 100644 --- a/model/src/pyrenew/process/simplerandomwalk.py +++ b/model/src/pyrenew/process/simplerandomwalk.py @@ -3,7 +3,7 @@ import jax.numpy as jnp from numpyro.contrib.control_flow import scan -from pyrenew.metaclass import RandomVariable +from pyrenew.metaclass import RandomVariable, SampledValue class SimpleRandomWalkProcess(RandomVariable): @@ -77,16 +77,22 @@ def sample( def transition(x_prev, _): # numpydoc ignore=GL08 diff, *_ = self.step_rv(**kwargs) - x_curr = x_prev + diff + x_curr = x_prev + diff.value return x_curr, x_curr _, x = scan( transition, - init=init, + init=init.value, xs=jnp.arange(n_steps - 1), ) - return (jnp.hstack([init, x.flatten()]),) + return ( + SampledValue( + jnp.hstack([init.value, x.flatten()]), + t_start=self.t_start, + t_unit=self.t_unit, + ), + ) @staticmethod def validate(): diff --git a/model/src/test/test_ar_process.py b/model/src/test/test_ar_process.py index 42a7e92d..e6554dd7 100755 --- a/model/src/test/test_ar_process.py +++ b/model/src/test/test_ar_process.py @@ -39,5 +39,5 @@ def test_ar_samples_correctly_distributed(): # check it regresses to mean # when started away from it long_ts, *_ = ar1(duration=10000, inits=ar_inits) - assert_almost_equal(long_ts[0], ar_inits) - assert jnp.abs(long_ts[-1] - ar_mean) < 4 * noise_sd + assert_almost_equal(long_ts.value[0], ar_inits) + assert jnp.abs(long_ts.value[-1] - ar_mean) < 4 * noise_sd diff --git a/model/src/test/test_deterministic.py b/model/src/test/test_deterministic.py index 243eba39..6e72f8cb 100644 --- a/model/src/test/test_deterministic.py +++ b/model/src/test/test_deterministic.py @@ -33,7 +33,7 @@ def test_deterministic(): var5 = NullProcess() testing.assert_array_equal( - var1()[0], + var1()[0].value, jnp.array( [ 1, @@ -41,16 +41,16 @@ def test_deterministic(): ), ) testing.assert_array_equal( - var2()[0], + var2()[0].value, jnp.array([0.25, 0.25, 0.2, 0.3]), ) testing.assert_array_equal( - var3(duration=5)[0], + var3(duration=5)[0].value, jnp.array([1, 2, 3, 4, 4]), ) testing.assert_array_equal( - var3(duration=3)[0], + var3(duration=3)[0].value, jnp.array( [ 1, @@ -60,5 +60,5 @@ def test_deterministic(): ), ) - testing.assert_equal(var4()[0], None) - testing.assert_equal(var5(duration=1)[0], None) + testing.assert_equal(var4()[0].value, None) + testing.assert_equal(var5(duration=1)[0].value, None) diff --git a/model/src/test/test_first_difference_ar.py b/model/src/test/test_first_difference_ar.py index 303df6b7..14eed675 100755 --- a/model/src/test/test_first_difference_ar.py +++ b/model/src/test/test_first_difference_ar.py @@ -26,5 +26,5 @@ def test_fd_ar_can_be_sampled(): ) # Checking proper shape - assert ans0[0].shape == (3532,) - assert ans1[0].shape == (3532,) + assert ans0[0].value.shape == (3532,) + assert ans1[0].value.shape == (3532,) diff --git a/model/src/test/test_forecast.py b/model/src/test/test_forecast.py index 5de2c1a4..f6b736a0 100644 --- a/model/src/test/test_forecast.py +++ b/model/src/test/test_forecast.py @@ -60,7 +60,7 @@ def test_forecast(): model.run( num_warmup=5, num_samples=5, - data_observed_infections=model_sample.observed_infections, + data_observed_infections=model_sample.observed_infections.value, rng_key=jr.key(54), ) diff --git a/model/src/test/test_infection_seeding_method.py b/model/src/test/test_infection_seeding_method.py index 7eb40a20..44edd3c5 100644 --- a/model/src/test/test_infection_seeding_method.py +++ b/model/src/test/test_infection_seeding_method.py @@ -20,6 +20,8 @@ def test_initialize_infections_exponential(): (I_pre_init,) = I_pre_init_RV() (rate,) = rate_RV() + I_pre_init = I_pre_init.value + rate = rate.value infections_default_t_pre_init = InitializeInfectionsExponentialGrowth( n_timepoints, rate=rate_RV ).initialize_infections(I_pre_init) @@ -52,7 +54,7 @@ def test_initialize_infections_exponential(): with pytest.raises(ValueError): InitializeInfectionsExponentialGrowth( n_timepoints, rate=rate_RV - ).initialize_infections(I_pre_init_2) + ).initialize_infections(I_pre_init_2.value) # test non-default t_pre_init t_pre_init = 6 @@ -77,6 +79,7 @@ def test_initialize_infections_zero_pad(): n_timepoints = 10 I_pre_init_RV = DeterministicVariable(name="I_pre_init_RV", value=10.0) (I_pre_init,) = I_pre_init_RV() + I_pre_init = I_pre_init.value infections = InitializeInfectionsZeroPad( n_timepoints @@ -88,7 +91,9 @@ def test_initialize_infections_zero_pad(): I_pre_init_RV_2 = DeterministicVariable( name="I_pre_init_RV", value=np.array([10.0, 10.0]) ) + (I_pre_init_2,) = I_pre_init_RV_2() + I_pre_init_2 = I_pre_init_2.value infections_2 = InitializeInfectionsZeroPad( n_timepoints diff --git a/model/src/test/test_infectionsrtfeedback.py b/model/src/test/test_infectionsrtfeedback.py index bd15f5cd..46032e61 100644 --- a/model/src/test/test_infectionsrtfeedback.py +++ b/model/src/test/test_infectionsrtfeedback.py @@ -95,10 +95,10 @@ def test_infectionsrtfeedback(): ) assert_array_equal( - samp1.post_initialization_infections, - samp2.post_initialization_infections, + samp1.post_initialization_infections.value, + samp2.post_initialization_infections.value, ) - assert_array_equal(samp1.rt, Rt) + assert_array_equal(samp1.rt.value, Rt) return None @@ -142,18 +142,18 @@ def test_infectionsrtfeedback_feedback(): gen_int=gen_int, Rt=Rt, I0=I0, - inf_feedback_strength=inf_feed_strength()[0], - inf_feedback_pmf=inf_feedback_pmf()[0], + inf_feedback_strength=inf_feed_strength()[0].value, + inf_feedback_pmf=inf_feedback_pmf()[0].value, ) assert not jnp.array_equal( - samp1.post_initialization_infections, - samp2.post_initialization_infections, + samp1.post_initialization_infections.value, + samp2.post_initialization_infections.value, ) assert_array_almost_equal( - samp1.post_initialization_infections, + samp1.post_initialization_infections.value, res["post_initialization_infections"], ) - assert_array_almost_equal(samp1.rt, res["rt"]) + assert_array_almost_equal(samp1.rt.value, res["rt"]) return None diff --git a/model/src/test/test_latent_admissions.py b/model/src/test/test_latent_admissions.py index ca49acb6..cf841a8d 100644 --- a/model/src/test/test_latent_admissions.py +++ b/model/src/test/test_latent_admissions.py @@ -35,7 +35,7 @@ def test_admissions_sample(): ) with numpyro.handlers.seed(rng_seed=223): - sim_rt, *_ = rt(n_steps=30) + sim_rt = rt(n_steps=30)[0].value gen_int = jnp.array([0.5, 0.1, 0.1, 0.2, 0.1]) i0 = 10 * jnp.ones_like(gen_int) @@ -80,9 +80,9 @@ def test_admissions_sample(): ) with numpyro.handlers.seed(rng_seed=223): - sim_hosp_1 = hosp1(latent_infections=inf_sampled1[0]) + sim_hosp_1 = hosp1(latent_infections=inf_sampled1[0].value) testing.assert_array_less( - sim_hosp_1.latent_hospital_admissions, - inf_sampled1[0], + sim_hosp_1.latent_hospital_admissions.value, + inf_sampled1[0].value, ) diff --git a/model/src/test/test_latent_infections.py b/model/src/test/test_latent_infections.py index d472a1d7..d55c7dff 100755 --- a/model/src/test/test_latent_infections.py +++ b/model/src/test/test_latent_infections.py @@ -40,7 +40,7 @@ def test_infections_as_deterministic(): inf1 = Infections() obs = dict( - Rt=sim_rt, + Rt=sim_rt.value, I0=jnp.zeros(gen_int.size), gen_int=gen_int, ) @@ -49,8 +49,8 @@ def test_infections_as_deterministic(): inf_sampled2 = inf1(**obs) testing.assert_array_equal( - inf_sampled1.post_initialization_infections, - inf_sampled2.post_initialization_infections, + inf_sampled1.post_initialization_infections.value, + inf_sampled2.post_initialization_infections.value, ) # Check that Initial infections vector must be at least as long as the generation interval. diff --git a/model/src/test/test_model_basic_renewal.py b/model/src/test/test_model_basic_renewal.py index 83702b47..c44737c9 100644 --- a/model/src/test/test_model_basic_renewal.py +++ b/model/src/test/test_model_basic_renewal.py @@ -160,20 +160,21 @@ def test_model_basicrenewal_no_obs_model(): with numpyro.handlers.seed(rng_seed=223): model1_samp = model0.sample(n_datapoints=30) - np.testing.assert_array_equal(model0_samp.Rt, model1_samp.Rt) + np.testing.assert_array_equal(model0_samp.Rt.value, model1_samp.Rt.value) np.testing.assert_array_equal( - model0_samp.latent_infections, model1_samp.latent_infections + model0_samp.latent_infections.value, + model1_samp.latent_infections.value, ) np.testing.assert_array_equal( - model0_samp.observed_infections, - model1_samp.observed_infections, + model0_samp.observed_infections.value, + model1_samp.observed_infections.value, ) model0.run( num_warmup=500, num_samples=500, rng_key=jr.key(272), - data_observed_infections=model0_samp.latent_infections, + data_observed_infections=model0_samp.latent_infections.value, ) inf = model0.spread_draws(["all_latent_infections"]) @@ -227,7 +228,7 @@ def test_model_basicrenewal_with_obs_model(): num_warmup=500, num_samples=500, rng_key=jr.key(22), - data_observed_infections=model1_samp.observed_infections, + data_observed_infections=model1_samp.observed_infections.value, ) inf = model1.spread_draws(["all_latent_infections"]) @@ -277,7 +278,7 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08 num_warmup=500, num_samples=500, rng_key=jr.key(22), - data_observed_infections=model1_samp.observed_infections, + data_observed_infections=model1_samp.observed_infections.value, padding=5, ) diff --git a/model/src/test/test_model_hosp_admissions.py b/model/src/test/test_model_hosp_admissions.py index 2b8d113e..a573ec06 100644 --- a/model/src/test/test_model_hosp_admissions.py +++ b/model/src/test/test_model_hosp_admissions.py @@ -24,6 +24,7 @@ from pyrenew.metaclass import ( DistributionalRV, RandomVariable, + SampledValue, TransformedRandomVariable, ) from pyrenew.model import HospitalAdmissionsModel @@ -69,8 +70,10 @@ def validate(self): # numpydoc ignore=GL08 def sample(self, **kwargs): # numpydoc ignore=GL08 return ( - numpyro.sample( - name=self.name, fn=dist.Uniform(high=0.99, low=0.01) + SampledValue( + numpyro.sample( + name=self.name, fn=dist.Uniform(high=0.99, low=0.01) + ) ), ) @@ -280,26 +283,31 @@ def test_model_hosp_no_obs_model(): with numpyro.handlers.seed(rng_seed=223): model1_samp = model0.sample(n_datapoints=30) - np.testing.assert_array_almost_equal(model0_samp.Rt, model1_samp.Rt) - np.testing.assert_array_equal( - model0_samp.latent_infections, model1_samp.latent_infections + np.testing.assert_array_almost_equal( + model0_samp.Rt.value, model1_samp.Rt.value ) np.testing.assert_array_equal( - model0_samp.infection_hosp_rate, model1_samp.infection_hosp_rate + model0_samp.latent_infections.value, + model1_samp.latent_infections.value, ) np.testing.assert_array_equal( - model0_samp.latent_hosp_admissions, model1_samp.latent_hosp_admissions + model0_samp.infection_hosp_rate.value, + model1_samp.infection_hosp_rate.value, ) np.testing.assert_array_equal( - model0_samp.observed_hosp_admissions, - model1_samp.observed_hosp_admissions, + model0_samp.latent_hosp_admissions.value, + model1_samp.latent_hosp_admissions.value, ) + # These are supposed to be none, both + assert model0_samp.observed_hosp_admissions.value is None + assert model1_samp.observed_hosp_admissions.value is None + model0.run( num_warmup=500, num_samples=500, rng_key=jr.key(272), - data_observed_hosp_admissions=model0_samp.latent_hosp_admissions, + data_observed_hosp_admissions=model0_samp.latent_hosp_admissions.value, ) inf = model0.spread_draws(["latent_hospital_admissions"]) @@ -385,7 +393,7 @@ def test_model_hosp_with_obs_model(): num_warmup=500, num_samples=500, rng_key=jr.key(272), - data_observed_hosp_admissions=model1_samp.observed_hosp_admissions, + data_observed_hosp_admissions=model1_samp.observed_hosp_admissions.value, ) inf = model1.spread_draws(["latent_hospital_admissions"]) @@ -482,7 +490,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): num_warmup=500, num_samples=500, rng_key=jr.key(272), - data_observed_hosp_admissions=model1_samp.observed_hosp_admissions, + data_observed_hosp_admissions=model1_samp.observed_hosp_admissions.value, ) inf = model1.spread_draws(["latent_hospital_admissions"]) @@ -591,12 +599,20 @@ def test_model_hosp_with_obs_model_weekday_phosp(): n_datapoints=n_obs_to_generate, padding=pad_size ) + # Showed during merge conflict, but unsure if it will be needed + # pad_size = 5 + # obs = jnp.hstack( + # [ + # jnp.repeat(jnp.nan, pad_size), + # model1_samp.observed_hosp_admissions.value[pad_size:], + # ] + # ) # Running with padding model1.run( num_warmup=500, num_samples=500, rng_key=jr.key(272), - data_observed_hosp_admissions=model1_samp.observed_hosp_admissions, + data_observed_hosp_admissions=model1_samp.observed_hosp_admissions.value, padding=pad_size, ) diff --git a/model/src/test/test_observation_negativebinom.py b/model/src/test/test_observation_negativebinom.py index 80507c7b..b369e1ed 100644 --- a/model/src/test/test_observation_negativebinom.py +++ b/model/src/test/test_observation_negativebinom.py @@ -26,12 +26,12 @@ def test_negativebinom_deterministic_obs(): assert isinstance(sim_nb1, tuple) assert isinstance(sim_nb2, tuple) - assert isinstance(sim_nb1[0], ArrayLike) - assert isinstance(sim_nb2[0], ArrayLike) + assert isinstance(sim_nb1[0].value, ArrayLike) + assert isinstance(sim_nb2[0].value, ArrayLike) testing.assert_array_equal( - sim_nb1[0], - sim_nb2[0], + sim_nb1[0].value, + sim_nb2[0].value, ) @@ -51,11 +51,11 @@ def test_negativebinom_random_obs(): sim_nb2 = negb(mu=rates) assert isinstance(sim_nb1, tuple) assert isinstance(sim_nb2, tuple) - assert isinstance(sim_nb1[0], ArrayLike) - assert isinstance(sim_nb2[0], ArrayLike) + assert isinstance(sim_nb1[0].value, ArrayLike) + assert isinstance(sim_nb2[0].value, ArrayLike) testing.assert_array_almost_equal( - np.mean(sim_nb1[0]), - np.mean(sim_nb2[0]), + np.mean(sim_nb1[0].value), + np.mean(sim_nb2[0].value), decimal=1, ) diff --git a/model/src/test/test_observation_poisson.py b/model/src/test/test_observation_poisson.py index e1844b11..10e37f87 100644 --- a/model/src/test/test_observation_poisson.py +++ b/model/src/test/test_observation_poisson.py @@ -19,4 +19,4 @@ def test_poisson_obs(): with numpyro.handlers.seed(rng_seed=223): sim_pois, *_ = pois(mu=rates) - testing.assert_array_equal(sim_pois, jnp.ceil(sim_pois)) + testing.assert_array_equal(sim_pois.value, jnp.ceil(sim_pois.value)) diff --git a/model/src/test/test_periodiceffect.py b/model/src/test/test_periodiceffect.py index 85fbcfb8..7173efa4 100644 --- a/model/src/test/test_periodiceffect.py +++ b/model/src/test/test_periodiceffect.py @@ -28,7 +28,7 @@ def test_periodiceffect() -> None: pe = PeriodicEffect(**params) with numpyro.handlers.seed(rng_seed=223): - ans = pe(duration=duration).value + ans = pe(duration=duration)[0].value # Checking that the shape of the sampled Rt is correct assert ans.shape == (duration,) @@ -42,9 +42,9 @@ def test_periodiceffect() -> None: params["offset"] = 5 pe = PeriodicEffect(**params) with numpyro.handlers.seed(rng_seed=223): - ans2 = pe(duration=duration).value + ans2 = pe(duration=duration)[0].value - # Checking that the shape of the sampled Rt is correct + ans2 = pe(duration=duration)[0].value assert ans2.shape == (duration,) # This time series should be the same as the previous one, but shifted by @@ -79,8 +79,8 @@ def test_weeklyeffect() -> None: pe = PeriodicEffect(**params) pe2 = DayOfWeekEffect(**params2) - ans1 = pe(duration=duration).value - ans2 = pe2(duration=duration).value + ans1 = pe(duration=duration)[0].value + ans2 = pe2(duration=duration)[0].value assert_array_equal(ans1, ans2) diff --git a/model/src/test/test_random_key.py b/model/src/test/test_random_key.py index 81565f8e..d032bb93 100644 --- a/model/src/test/test_random_key.py +++ b/model/src/test/test_random_key.py @@ -108,7 +108,7 @@ def test_rng_keys_produce_correct_samples(): # as the observed_infections for the rest of the models with numpyro.handlers.seed(rng_seed=223): model_sample = models[0].sample(n_datapoints=n_datapoints[0]) - obs_infections = [model_sample.observed_infections] * len(models) + obs_infections = [model_sample.observed_infections.value] * len(models) rng_keys = [jr.key(54), jr.key(54), None, None, jr.key(74)] # run test models with the different keys diff --git a/model/src/test/test_random_walk.py b/model/src/test/test_random_walk.py index 76ea3c73..242e0400 100755 --- a/model/src/test/test_random_walk.py +++ b/model/src/test/test_random_walk.py @@ -40,12 +40,12 @@ def test_rw_can_be_sampled(): ans_fixed = rw_init_fixed(n_steps=5023) # check that the samples are of the right shape - assert ans_rand[0].shape == (3532,) - assert ans_fixed[0].shape == (5023,) + assert ans_rand[0].value.shape == (3532,) + assert ans_fixed[0].value.shape == (5023,) # check that fixing inits works - assert_almost_equal(ans_fixed[0][0], init_rv_fixed.value) - assert ans_rand[0][0] != init_rv_fixed.value + assert_almost_equal(ans_fixed[0].value[0], init_rv_fixed.value) + assert ans_rand[0].value[0] != init_rv_fixed.value def test_rw_samples_correctly_distributed(): @@ -72,6 +72,7 @@ def test_rw_samples_correctly_distributed(): with numpyro.handlers.seed(rng_seed=62): samples, *_ = rw_normal(n_steps=n_samples) + samples = samples.value # Checking the shape assert samples.shape == (n_samples,) diff --git a/model/src/test/test_rtperiodicdiff.py b/model/src/test/test_rtperiodicdiff.py index be6e0d0a..4fb2cbfb 100644 --- a/model/src/test/test_rtperiodicdiff.py +++ b/model/src/test/test_rtperiodicdiff.py @@ -66,8 +66,8 @@ def test_rtweeklydiff() -> None: rtwd = RtWeeklyDiffProcess(**params) - with numpyro.handlers.seed(rng_seed=121): - rt = rtwd(duration=duration).rt + with numpyro.handlers.seed(rng_seed=223): + rt = rtwd(duration=duration).rt.value # Checking that the shape of the sampled Rt is correct assert rt.shape == (duration,) @@ -80,14 +80,15 @@ def test_rtweeklydiff() -> None: # Checking start off a different day of the week params["offset"] = 5 rtwd = RtWeeklyDiffProcess(**params) - with numpyro.handlers.seed(rng_seed=121): - rt2 = rtwd(duration=duration).rt + + with numpyro.handlers.seed(rng_seed=223): + rt2 = rtwd(duration=duration).rt.value # Checking that the shape of the sampled Rt is correct assert rt2.shape == (duration,) - # This time series should be the same as the previous one, but shifted by - # 5 days + # This time series should be the same as the previous one, + # but shifted by 5 days assert_array_equal(rt[5:], rt2[:-5]) return None @@ -115,8 +116,9 @@ def test_rtweeklydiff_no_autoregressive() -> None: rtwd = RtWeeklyDiffProcess(**params) duration = 1000 + with numpyro.handlers.seed(rng_seed=323): - rt = rtwd(duration=duration).rt + rt = rtwd(duration=duration).rt.value # Checking that the shape of the sampled Rt is correct assert rt.shape == (duration,) @@ -159,12 +161,12 @@ def test_rtweeklydiff_manual_reconstruction() -> None: _, ans0 = lax.scan( f=rtwd.autoreg_process, - init=np.hstack([params["log_rt_prior"]()[0], b]), + init=np.hstack([params["log_rt_prior"]()[0].value, b]), xs=noise, ) ans1 = _manual_rt_weekly_diff( - log_seed=params["log_rt_prior"]()[0], sd=noise, b=b + log_seed=params["log_rt_prior"]()[0].value, sd=noise, b=b ) assert_array_almost_equal(ans0, ans1) @@ -194,7 +196,7 @@ def test_rtperiodicdiff_smallsample(): rtwd = RtWeeklyDiffProcess(**params) with numpyro.handlers.seed(rng_seed=223): - rt = rtwd(duration=6).rt + rt = rtwd(duration=6).rt.value # Checking that the shape of the sampled Rt is correct assert rt.shape == (6,) diff --git a/model/src/test/test_transformed_rv_class.py b/model/src/test/test_transformed_rv_class.py index 134a9180..210041a5 100644 --- a/model/src/test/test_transformed_rv_class.py +++ b/model/src/test/test_transformed_rv_class.py @@ -12,6 +12,7 @@ from pyrenew.metaclass import ( DistributionalRV, RandomVariable, + SampledValue, TransformedRandomVariable, ) @@ -31,9 +32,12 @@ def sample(self, **kwargs): Returns ------- tuple - (1, 5) + (SampledValue(1, t_start=self.t_start, t_unit=self.t_unit), SampledValue(5, t_start=self.t_start, t_unit=self.t_unit)) """ - return (1, 5) + return ( + SampledValue(1, t_start=self.t_start, t_unit=self.t_unit), + SampledValue(5, t_start=self.t_start, t_unit=self.t_unit), + ) def sample_length(self): """ @@ -130,9 +134,12 @@ def test_transforms_applied_at_sampling(): l2_transformed_sample = tr_l2.sample() assert_almost_equal( - (tr(norm_base_sample[0]),), norm_transformed_sample + tr(norm_base_sample[0].value), norm_transformed_sample[0].value ) assert_almost_equal( - (tr(l2_base_sample[0]), t.ExpTransform()(l2_base_sample[1])), - l2_transformed_sample, + ( + tr(l2_base_sample[0].value), + t.ExpTransform()(l2_base_sample[1].value), + ), + (l2_transformed_sample[0].value, l2_transformed_sample[1].value), )