Skip to content

Commit

Permalink
Rename random variables in rtperiodicdiff (#339)
Browse files Browse the repository at this point in the history
damonbayer authored Jul 30, 2024
1 parent ff33e31 commit f77579e
Showing 13 changed files with 75 additions and 79 deletions.
2 changes: 1 addition & 1 deletion docs/source/tutorials/basic_renewal_model.qmd
Original file line number Diff line number Diff line change
@@ -152,7 +152,7 @@ class MyRt(RandomVariable):
reparam=LocScaleReparam(0),
),
init_rv=DistributionalRV(
name="init_log_Rt_rv",
name="init_log_rt",
dist=dist.Normal(jnp.log(1), jnp.log(1.2)),
),
),
4 changes: 1 addition & 3 deletions docs/source/tutorials/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
@@ -69,9 +69,7 @@ rt = TransformedRandomVariable(
step_rv=DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
),
init_rv=DistributionalRV(name="init_log_rt", dist=dist.Normal(0, 0.2)),
),
transforms=t.ExpTransform(),
)
2 changes: 1 addition & 1 deletion docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
@@ -202,7 +202,7 @@ class MyRt(metaclass.RandomVariable):
name="rw_step_rv", dist=dist.Normal(0, sd_rt.value)
),
init_rv=metaclass.DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=transformation.ExpTransform(),
8 changes: 4 additions & 4 deletions docs/source/tutorials/periodic_effects.qmd
Original file line number Diff line number Diff line change
@@ -27,13 +27,13 @@ from pyrenew import process, deterministic
rt_proc = process.RtWeeklyDiffProcess(
name="rt_weekly_diff",
offset=0,
log_rt_prior=deterministic.DeterministicVariable(
name="log_rt_prior", value=jnp.array([0.1, 0.2])
log_rt_rv=deterministic.DeterministicVariable(
name="log_rt", value=jnp.array([0.1, 0.2])
),
autoreg=deterministic.DeterministicVariable(
autoreg_rv=deterministic.DeterministicVariable(
name="autoreg", value=jnp.array([0.7])
),
periodic_diff_sd=deterministic.DeterministicVariable(
periodic_diff_sd_rv=deterministic.DeterministicVariable(
name="periodic_diff_sd", value=jnp.array([0.1])
),
)
70 changes: 35 additions & 35 deletions model/src/pyrenew/process/rtperiodicdiff.py
Original file line number Diff line number Diff line change
@@ -52,9 +52,9 @@ def __init__(
name: str,
offset: int,
period_size: int,
log_rt_prior: RandomVariable,
autoreg: RandomVariable,
periodic_diff_sd: RandomVariable,
log_rt_rv: RandomVariable,
autoreg_rv: RandomVariable,
periodic_diff_sd_rv: RandomVariable,
) -> None:
"""
Default constructor for RtPeriodicDiffProcess class.
@@ -66,11 +66,11 @@ def __init__(
offset : int
Relative point at which data starts, must be between 0 and
period_size - 1.
log_rt_prior : RandomVariable
log_rt_rv : RandomVariable
Log Rt prior for the first two observations.
autoreg : RandomVariable
autoreg_rv : RandomVariable
Autoregressive parameter.
periodic_diff_sd : RandomVariable
periodic_diff_sd_rv : RandomVariable
Standard deviation of the noise.
Returns
@@ -85,45 +85,45 @@ def __init__(
)

self.validate(
log_rt_prior=log_rt_prior,
autoreg=autoreg,
periodic_diff_sd=periodic_diff_sd,
log_rt_rv=log_rt_rv,
autoreg_rv=autoreg_rv,
periodic_diff_sd_rv=periodic_diff_sd_rv,
)

self.period_size = period_size
self.offset = offset
self.log_rt_prior = log_rt_prior
self.autoreg = autoreg
self.periodic_diff_sd = periodic_diff_sd
self.log_rt_rv = log_rt_rv
self.autoreg_rv = autoreg_rv
self.periodic_diff_sd_rv = periodic_diff_sd_rv

return None

@staticmethod
def validate(
log_rt_prior: any,
autoreg: any,
periodic_diff_sd: any,
log_rt_rv: any,
autoreg_rv: any,
periodic_diff_sd_rv: any,
) -> None:
"""
Validate the input parameters.
Parameters
----------
log_rt_prior : any
log_rt_rv : any
Log Rt prior for the first two observations.
autoreg : any
autoreg_rv : any
Autoregressive parameter.
periodic_diff_sd : any
periodic_diff_sd_rv : any
Standard deviation of the noise.
Returns
-------
None
"""

_assert_sample_and_rtype(log_rt_prior)
_assert_sample_and_rtype(autoreg)
_assert_sample_and_rtype(periodic_diff_sd)
_assert_sample_and_rtype(log_rt_rv)
_assert_sample_and_rtype(autoreg_rv)
_assert_sample_and_rtype(periodic_diff_sd_rv)

return None

@@ -175,9 +175,9 @@ def sample(
"""

# Initial sample
log_rt_prior = self.log_rt_prior.sample(**kwargs)[0].value
b = self.autoreg.sample(**kwargs)[0].value
s_r = self.periodic_diff_sd.sample(**kwargs)[0].value
log_rt_rv = self.log_rt_rv.sample(**kwargs)[0].value
b = self.autoreg_rv.sample(**kwargs)[0].value
s_r = self.periodic_diff_sd_rv.sample(**kwargs)[0].value

# How many periods to sample?
n_periods = int(jnp.ceil(duration / self.period_size))
@@ -186,8 +186,8 @@ def sample(
ar_diff = FirstDifferenceARProcess("trend_rw", autoreg=b, noise_sd=s_r)
log_rt = ar_diff.sample(
duration=n_periods,
init_val=log_rt_prior[1],
init_rate_of_change=log_rt_prior[1] - log_rt_prior[0],
init_val=log_rt_rv[1],
init_rate_of_change=log_rt_rv[1] - log_rt_rv[0],
)[0]

return RtPeriodicDiffProcessSample(
@@ -208,9 +208,9 @@ def __init__(
self,
name: str,
offset: int,
log_rt_prior: RandomVariable,
autoreg: RandomVariable,
periodic_diff_sd: RandomVariable,
log_rt_rv: RandomVariable,
autoreg_rv: RandomVariable,
periodic_diff_sd_rv: RandomVariable,
) -> None:
"""
Default constructor for RtWeeklyDiffProcess class.
@@ -221,11 +221,11 @@ def __init__(
Name of the site.
offset : int
Relative point at which data starts, must be between 0 and 6.
log_rt_prior : RandomVariable
log_rt_rv : RandomVariable
Log Rt prior for the first two observations.
autoreg : RandomVariable
autoreg_rv : RandomVariable
Autoregressive parameter.
periodic_diff_sd : RandomVariable
periodic_diff_sd_rv : RandomVariable
Standard deviation of the noise.
Returns
@@ -237,9 +237,9 @@ def __init__(
name=name,
offset=offset,
period_size=7,
log_rt_prior=log_rt_prior,
autoreg=autoreg,
periodic_diff_sd=periodic_diff_sd,
log_rt_rv=log_rt_rv,
autoreg_rv=autoreg_rv,
periodic_diff_sd_rv=periodic_diff_sd_rv,
)

return None
2 changes: 1 addition & 1 deletion model/src/test/test_forecast.py
Original file line number Diff line number Diff line change
@@ -38,7 +38,7 @@ def test_forecast():
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
2 changes: 1 addition & 1 deletion model/src/test/test_latent_admissions.py
Original file line number Diff line number Diff line change
@@ -28,7 +28,7 @@ def test_admissions_sample():
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
2 changes: 1 addition & 1 deletion model/src/test/test_latent_infections.py
Original file line number Diff line number Diff line change
@@ -26,7 +26,7 @@ def test_infections_as_deterministic():
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
2 changes: 1 addition & 1 deletion model/src/test/test_model_basic_renewal.py
Original file line number Diff line number Diff line change
@@ -41,7 +41,7 @@ def get_default_rt():
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
2 changes: 1 addition & 1 deletion model/src/test/test_model_hosp_admissions.py
Original file line number Diff line number Diff line change
@@ -51,7 +51,7 @@ def get_default_rt():
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
4 changes: 1 addition & 3 deletions model/src/test/test_predictive.py
Original file line number Diff line number Diff line change
@@ -36,9 +36,7 @@
step_rv=DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
),
init_rv=DistributionalRV(name="init_log_rt", dist=dist.Normal(0, 0.2)),
),
transforms=t.ExpTransform(),
)
2 changes: 1 addition & 1 deletion model/src/test/test_random_key.py
Original file line number Diff line number Diff line change
@@ -42,7 +42,7 @@ def create_test_model(): # numpydoc ignore=GL08
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
52 changes: 26 additions & 26 deletions model/src/test/test_rtperiodicdiff.py
Original file line number Diff line number Diff line change
@@ -52,14 +52,14 @@ def test_rtweeklydiff() -> None:
params = {
"name": "test",
"offset": 0,
"log_rt_prior": DeterministicVariable(
name="log_rt_prior", value=jnp.array([0.1, 0.2])
"log_rt_rv": DeterministicVariable(
name="log_rt", value=jnp.array([0.1, 0.2])
),
"autoreg": DeterministicVariable(
name="autoreg", value=jnp.array([0.7])
"autoreg_rv": DeterministicVariable(
name="autoreg_rv", value=jnp.array([0.7])
),
"periodic_diff_sd": DeterministicVariable(
name="periodic_diff_sd", value=jnp.array([0.1])
"periodic_diff_sd_rv": DeterministicVariable(
name="periodic_diff_sd_rv", value=jnp.array([0.1])
),
}
duration = 30
@@ -100,15 +100,15 @@ def test_rtweeklydiff_no_autoregressive() -> None:
params = {
"name": "test",
"offset": 0,
"log_rt_prior": DeterministicVariable(
name="log_rt_prior", value=jnp.array([0.0, 0.0])
"log_rt_rv": DeterministicVariable(
name="log_rt", value=jnp.array([0.0, 0.0])
),
# No autoregression!
"autoreg": DeterministicVariable(
name="autoreg", value=jnp.array([0.0])
"autoreg_rv": DeterministicVariable(
name="autoreg_rv", value=jnp.array([0.0])
),
"periodic_diff_sd": DeterministicVariable(
name="periodic_diff_sd",
"periodic_diff_sd_rv": DeterministicVariable(
name="periodic_diff_sd_rv",
value=jnp.array([0.1]),
),
}
@@ -141,15 +141,15 @@ def test_rtweeklydiff_manual_reconstruction() -> None:
params = {
"name": "test",
"offset": 0,
"log_rt_prior": DeterministicVariable(
name="log_rt_prior",
"log_rt_rv": DeterministicVariable(
name="log_rt",
value=jnp.array([0.1, 0.2]),
),
"autoreg": DeterministicVariable(
name="autoreg", value=jnp.array([0.7])
"autoreg_rv": DeterministicVariable(
name="autoreg_rv", value=jnp.array([0.7])
),
"periodic_diff_sd": DeterministicVariable(
name="periodic_diff_sd",
"periodic_diff_sd_rv": DeterministicVariable(
name="periodic_diff_sd_rv",
value=jnp.array([0.1]),
),
}
@@ -161,12 +161,12 @@ def test_rtweeklydiff_manual_reconstruction() -> None:

_, ans0 = lax.scan(
f=rtwd.autoreg_process,
init=np.hstack([params["log_rt_prior"]()[0].value, b]),
init=np.hstack([params["log_rt_rv"]()[0].value, b]),
xs=noise,
)

ans1 = _manual_rt_weekly_diff(
log_seed=params["log_rt_prior"]()[0].value, sd=noise, b=b
log_seed=params["log_rt_rv"]()[0].value, sd=noise, b=b
)

assert_array_almost_equal(ans0, ans1)
@@ -180,15 +180,15 @@ def test_rtperiodicdiff_smallsample():
params = {
"name": "test",
"offset": 0,
"log_rt_prior": DeterministicVariable(
name="log_rt_prior",
"log_rt_rv": DeterministicVariable(
name="log_rt",
value=jnp.array([0.1, 0.2]),
),
"autoreg": DeterministicVariable(
name="autoreg", value=jnp.array([0.7])
"autoreg_rv": DeterministicVariable(
name="autoreg_rv", value=jnp.array([0.7])
),
"periodic_diff_sd": DeterministicVariable(
name="periodic_diff_sd",
"periodic_diff_sd_rv": DeterministicVariable(
name="periodic_diff_sd_rv",
value=jnp.array([0.1]),
),
}

0 comments on commit f77579e

Please sign in to comment.