diff --git a/docs/source/tutorials/periodic_effects.qmd b/docs/source/tutorials/periodic_effects.qmd index 116bc75f..e79cb833 100644 --- a/docs/source/tutorials/periodic_effects.qmd +++ b/docs/source/tutorials/periodic_effects.qmd @@ -25,6 +25,7 @@ from pyrenew import process, deterministic ```{python} # The random process for Rt rt_proc = process.RtWeeklyDiffProcess( + name="rt_weekly_diff", offset=0, log_rt_prior=deterministic.DeterministicVariable( jnp.array([0.1, 0.2]), name="log_rt_prior" diff --git a/model/src/pyrenew/deterministic/deterministic.py b/model/src/pyrenew/deterministic/deterministic.py index cd602212..c10a9477 100644 --- a/model/src/pyrenew/deterministic/deterministic.py +++ b/model/src/pyrenew/deterministic/deterministic.py @@ -26,7 +26,7 @@ def __init__( ---------- vars : ArrayLike A tuple with arraylike objects. - name : str, optional + name : str A name to assign to the process. Returns diff --git a/model/src/pyrenew/deterministic/nullrv.py b/model/src/pyrenew/deterministic/nullrv.py index 435c68b6..f25e1929 100644 --- a/model/src/pyrenew/deterministic/nullrv.py +++ b/model/src/pyrenew/deterministic/nullrv.py @@ -131,7 +131,6 @@ def sample( self, mu: ArrayLike, obs: ArrayLike | None = None, - name: str | None = None, **kwargs, ) -> tuple: """ @@ -143,8 +142,6 @@ def sample( Unused parameter, represents mean of non-null distributions obs : ArrayLike, optional Observed data. Defaults to None. - name : str, optional - Name of the random variable. Defaults to None. **kwargs : dict, optional Additional keyword arguments passed through to internal sample calls, should there be any. diff --git a/model/src/pyrenew/latent/hospitaladmissions.py b/model/src/pyrenew/latent/hospitaladmissions.py index aa059bfe..8cb15270 100644 --- a/model/src/pyrenew/latent/hospitaladmissions.py +++ b/model/src/pyrenew/latent/hospitaladmissions.py @@ -65,7 +65,6 @@ def __init__( self, infection_to_admission_interval_rv: RandomVariable, infect_hosp_rate_rv: RandomVariable, - latent_hospital_admissions_varname: str = "latent_hospital_admissions", day_of_week_effect_rv: RandomVariable | None = None, hosp_report_prob_rv: RandomVariable | None = None, ) -> None: @@ -79,9 +78,6 @@ def __init__( pyrenew.observations.Deterministic). infect_hosp_rate_rv : RandomVariable Infection to hospitalization rate random variable. - latent_hospital_admissions_varname : str - Name to assign to the deterministic component in numpyro of - observed hospital admissions. day_of_week_effect_rv : RandomVariable, optional Day of the week effect. hosp_report_prob_rv : RandomVariable, optional @@ -104,10 +100,6 @@ def __init__( hosp_report_prob_rv, ) - self.latent_hospital_admissions_varname = ( - latent_hospital_admissions_varname - ) - self.infect_hosp_rate_rv = infect_hosp_rate_rv self.day_of_week_effect_rv = day_of_week_effect_rv self.hosp_report_prob_rv = hosp_report_prob_rv @@ -200,7 +192,7 @@ def sample( ) npro.deterministic( - self.latent_hospital_admissions_varname, latent_hospital_admissions + "latent_hospital_admissions", latent_hospital_admissions ) return HospitalAdmissionsSample( diff --git a/model/src/pyrenew/observation/poisson.py b/model/src/pyrenew/observation/poisson.py index 4f378197..cef4256f 100644 --- a/model/src/pyrenew/observation/poisson.py +++ b/model/src/pyrenew/observation/poisson.py @@ -24,7 +24,7 @@ def __init__( Parameters ---------- - name : str, optional + name : str Passed to numpyro.sample. eps : float, optional Small value added to the rate parameter to avoid zero values. diff --git a/model/src/pyrenew/process/ar.py b/model/src/pyrenew/process/ar.py index 9b7e7e31..5955f34a 100644 --- a/model/src/pyrenew/process/ar.py +++ b/model/src/pyrenew/process/ar.py @@ -20,6 +20,7 @@ class ARProcess(RandomVariable): def __init__( self, + name: str, mean: float, autoreg: ArrayLike, noise_sd: float, @@ -29,6 +30,8 @@ def __init__( Parameters ---------- + name : str + Name of the parameter passed to numpyro.sample. mean: float Mean parameter. autoreg : ArrayLike @@ -40,6 +43,7 @@ def __init__( ------- None """ + self.name = name self.mean = mean self.autoreg = autoreg self.noise_sd = noise_sd @@ -48,7 +52,6 @@ def sample( self, duration: int, inits: ArrayLike = None, - name: str = "arprocess", **kwargs, ) -> tuple: """ @@ -61,9 +64,6 @@ def sample( inits : ArrayLike, optional Initial points, if None, then these are sampled. Defaults to None. - name : str, optional - Name of the parameter passed to numpyro.sample. - Defaults to "arprocess". **kwargs : dict, optional Additional keyword arguments passed through to internal sample() calls, should there be any. @@ -76,7 +76,7 @@ def sample( order = self.autoreg.shape[0] if inits is None: inits = numpyro.sample( - name + "_sampled_inits", + self.name + "_sampled_inits", dist.Normal(0, self.noise_sd).expand((order,)), ) @@ -86,7 +86,7 @@ def _ar_scanner(carry, next): # numpydoc ignore=GL08 return new_carry, new_term noise = numpyro.sample( - name + "_noise", + self.name + "_noise", dist.Normal(0, self.noise_sd).expand((duration - inits.size,)), ) diff --git a/model/src/pyrenew/process/firstdifferencear.py b/model/src/pyrenew/process/firstdifferencear.py index 99d1e38e..4e429afd 100644 --- a/model/src/pyrenew/process/firstdifferencear.py +++ b/model/src/pyrenew/process/firstdifferencear.py @@ -18,6 +18,7 @@ class FirstDifferenceARProcess(RandomVariable): def __init__( self, + name: str, autoreg: ArrayLike, noise_sd: float, ) -> None: @@ -26,6 +27,8 @@ def __init__( Parameters ---------- + name : str + Passed to ARProcess() autoreg : ArrayLike Process parameters pyrenew.processesARprocess. noise_sd : float @@ -35,14 +38,16 @@ def __init__( ------- None """ - self.rate_of_change_proc = ARProcess(0, jnp.array([autoreg]), noise_sd) + self.rate_of_change_proc = ARProcess( + "arprocess", 0, jnp.array([autoreg]), noise_sd + ) + self.name = name def sample( self, duration: int, init_val: ArrayLike = None, init_rate_of_change: ArrayLike = None, - name: str = "trend_rw", **kwargs, ) -> tuple: """ @@ -56,8 +61,6 @@ def sample( Starting point of the AR process, by default None. init_rate_of_change : ArrayLike, optional Passed to ARProcess.sample, by default None. - name : str, optional - Passed to ARProcess(), by default "trend_rw" **kwargs : dict, optional Additional keyword arguments passed through to internal sample() calls, should there be any. @@ -70,7 +73,7 @@ def sample( rates_of_change, *_ = self.rate_of_change_proc.sample( duration=duration, inits=jnp.atleast_1d(init_rate_of_change), - name=name + "_rate_of_change", + name=self.name + "_rate_of_change", ) return (init_val + jnp.cumsum(rates_of_change.flatten()),) diff --git a/model/src/pyrenew/process/rtperiodicdiff.py b/model/src/pyrenew/process/rtperiodicdiff.py index fccbf80d..3491bd32 100644 --- a/model/src/pyrenew/process/rtperiodicdiff.py +++ b/model/src/pyrenew/process/rtperiodicdiff.py @@ -45,18 +45,20 @@ class RtPeriodicDiffProcess(RandomVariable): def __init__( self, + name: str, offset: int, period_size: int, log_rt_prior: RandomVariable, autoreg: RandomVariable, periodic_diff_sd: RandomVariable, - site_name: str = "rt_periodic_diff", ) -> None: """ Default constructor for RtPeriodicDiffProcess class. Parameters ---------- + name : str + Name of the site. offset : int Relative point at which data starts, must be between 0 and period_size - 1. @@ -66,14 +68,12 @@ def __init__( Autoregressive parameter. periodic_diff_sd : RandomVariable Standard deviation of the noise. - site_name : str, optional - Name of the site. Defaults to "rt_periodic_diff". Returns ------- None """ - + self.name = name self.broadcaster = PeriodicBroadcaster( offset=offset, period_size=period_size, @@ -91,7 +91,6 @@ def __init__( self.log_rt_prior = log_rt_prior self.autoreg = autoreg self.periodic_diff_sd = periodic_diff_sd - self.site_name = site_name return None @@ -180,7 +179,7 @@ def sample( n_periods = int(jnp.ceil(duration / self.period_size)) # Running the process - ar_diff = FirstDifferenceARProcess(autoreg=b, noise_sd=s_r) + ar_diff = FirstDifferenceARProcess("trend_rw", autoreg=b, noise_sd=s_r) log_rt = ar_diff.sample( duration=n_periods, init_val=log_rt_prior[1], @@ -199,17 +198,19 @@ class RtWeeklyDiffProcess(RtPeriodicDiffProcess): def __init__( self, + name: str, offset: int, log_rt_prior: RandomVariable, autoreg: RandomVariable, periodic_diff_sd: RandomVariable, - site_name: str = "rt_weekly_diff", ) -> None: """ Default constructor for RtWeeklyDiffProcess class. Parameters ---------- + name : str + Name of the site. offset : int Relative point at which data starts, must be between 0 and 6. log_rt_prior : RandomVariable @@ -218,8 +219,6 @@ def __init__( Autoregressive parameter. periodic_diff_sd : RandomVariable Standard deviation of the noise. - site_name : str, optional - Name of the site. Defaults to "rt_weekly_diff". Returns ------- @@ -227,12 +226,12 @@ def __init__( """ super().__init__( + name=name, offset=offset, period_size=7, log_rt_prior=log_rt_prior, autoreg=autoreg, periodic_diff_sd=periodic_diff_sd, - site_name=site_name, ) return None diff --git a/model/src/test/test_ar_process.py b/model/src/test/test_ar_process.py index fafaaa1f..42a7e92d 100755 --- a/model/src/test/test_ar_process.py +++ b/model/src/test/test_ar_process.py @@ -11,13 +11,15 @@ def test_ar_can_be_sampled(): Check that an AR process can be initialized and sampled from """ - ar1 = ARProcess(5, jnp.array([0.95]), jnp.array([0.5])) + ar1 = ARProcess("arprocess", 5, jnp.array([0.95]), jnp.array([0.5])) with numpyro.handlers.seed(rng_seed=62): # can sample with and without inits ar1(duration=3532, inits=jnp.array([50.0])) ar1(duration=5023) - ar3 = ARProcess(5, jnp.array([0.05, 0.025, 0.025]), jnp.array([0.5])) + ar3 = ARProcess( + "arprocess", 5, jnp.array([0.05, 0.025, 0.025]), jnp.array([0.5]) + ) with numpyro.handlers.seed(rng_seed=62): # can sample with and without inits ar3(duration=1230) @@ -32,7 +34,7 @@ def test_ar_samples_correctly_distributed(): ar_mean = 5 noise_sd = jnp.array([0.5]) ar_inits = jnp.array([25.0]) - ar1 = ARProcess(ar_mean, jnp.array([0.75]), noise_sd) + ar1 = ARProcess("arprocess", ar_mean, jnp.array([0.75]), noise_sd) with numpyro.handlers.seed(rng_seed=62): # check it regresses to mean # when started away from it diff --git a/model/src/test/test_first_difference_ar.py b/model/src/test/test_first_difference_ar.py index 16236054..303df6b7 100755 --- a/model/src/test/test_first_difference_ar.py +++ b/model/src/test/test_first_difference_ar.py @@ -13,7 +13,7 @@ def test_fd_ar_can_be_sampled(): can be initialized and sampled from """ - ar_fd = FirstDifferenceARProcess(0.5, 0.5) + ar_fd = FirstDifferenceARProcess("trend_rw", 0.5, 0.5) with numpyro.handlers.seed(rng_seed=62): # can sample with and without inits diff --git a/model/src/test/test_model_hosp_admissions.py b/model/src/test/test_model_hosp_admissions.py index 1abb2a3b..1dc54500 100644 --- a/model/src/test/test_model_hosp_admissions.py +++ b/model/src/test/test_model_hosp_admissions.py @@ -253,7 +253,6 @@ def test_model_hosp_no_obs_model(): latent_admissions = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, - latent_hospital_admissions_varname="latent_hospital_admissions", infect_hosp_rate_rv=DistributionalRV( dist=dist.LogNormal(jnp.log(0.05), 0.05), name="IHR" ), diff --git a/model/src/test/test_rtperiodicdiff.py b/model/src/test/test_rtperiodicdiff.py index 4da8173a..08cb85f2 100644 --- a/model/src/test/test_rtperiodicdiff.py +++ b/model/src/test/test_rtperiodicdiff.py @@ -50,6 +50,7 @@ def test_rtweeklydiff() -> None: """Checks basic functionality of the process""" params = { + "name": "test", "offset": 0, "log_rt_prior": DeterministicVariable( jnp.array([0.1, 0.2]), name="log_rt_prior" @@ -58,7 +59,6 @@ def test_rtweeklydiff() -> None: "periodic_diff_sd": DeterministicVariable( jnp.array([0.1]), name="periodic_diff_sd" ), - "site_name": "test", } duration = 30 @@ -97,6 +97,7 @@ def test_rtweeklydiff_no_autoregressive() -> None: """Checks step size averages close to 0""" params = { + "name": "test", "offset": 0, "log_rt_prior": DeterministicVariable( jnp.array([0.0, 0.0]), name="log_rt_prior" @@ -106,7 +107,6 @@ def test_rtweeklydiff_no_autoregressive() -> None: "periodic_diff_sd": DeterministicVariable( jnp.array([0.1]), name="periodic_diff_sd" ), - "site_name": "test", } rtwd = RtWeeklyDiffProcess(**params) @@ -135,6 +135,7 @@ def test_rtweeklydiff_manual_reconstruction() -> None: """Checks that the 'manual' reconstruction is correct""" params = { + "name": "test", "offset": 0, "log_rt_prior": DeterministicVariable( jnp.array([0.1, 0.2]), name="log_rt_prior" @@ -143,7 +144,6 @@ def test_rtweeklydiff_manual_reconstruction() -> None: "periodic_diff_sd": DeterministicVariable( jnp.array([0.1]), name="periodic_diff_sd" ), - "site_name": "test", } rtwd = RtWeeklyDiffProcess(**params) @@ -170,6 +170,7 @@ def test_rtperiodicdiff_smallsample(): """Checks basic functionality of the process with a small sample size.""" params = { + "name": "test", "offset": 0, "log_rt_prior": DeterministicVariable( jnp.array([0.1, 0.2]), name="log_rt_prior" @@ -178,7 +179,6 @@ def test_rtperiodicdiff_smallsample(): "periodic_diff_sd": DeterministicVariable( jnp.array([0.1]), name="periodic_diff_sd" ), - "site_name": "test", } rtwd = RtWeeklyDiffProcess(**params)