Skip to content

Commit

Permalink
Rename data variables in models for clarity (#195)
Browse files Browse the repository at this point in the history
* renaming variables for clarity in basic renewal model

* renaming arguments in hospitalization model for clarity

* fix tutorials
  • Loading branch information
damonbayer authored Jun 14, 2024
1 parent 38003b4 commit 657a206
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 74 deletions.
8 changes: 4 additions & 4 deletions model/docs/example_with_datasets.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ axs[0].plot(sim_data.Rt)
axs[0].set_ylabel("Rt")
# Infections plot
axs[1].plot(sim_data.sampled_observed_hosp_admissions)
axs[1].plot(sim_data.observed_hosp_admissions)
axs[1].set_ylabel("Infections")
axs[1].set_yscale("log")
Expand All @@ -236,7 +236,7 @@ import jax
hosp_model.run(
num_samples=2000,
num_warmup=2000,
observed_hosp_admissions=dat["daily_hosp_admits"].to_numpy(),
data_observed_hosp_admissions=dat["daily_hosp_admits"].to_numpy(),
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
)
Expand Down Expand Up @@ -281,7 +281,7 @@ dat_w_padding = np.pad(
hosp_model.run(
num_samples=2000,
num_warmup=2000,
observed_hosp_admissions=dat_w_padding,
data_observed_hosp_admissions=dat_w_padding,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
padding=days_to_impute, # Padding the model
Expand Down Expand Up @@ -382,7 +382,7 @@ Running the model (with the same padding as before):
hosp_model_weekday.run(
num_samples=2000,
num_warmup=2000,
observed_hosp_admissions=dat_w_padding,
data_observed_hosp_admissions=dat_w_padding,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
padding=days_to_impute,
Expand Down
4 changes: 2 additions & 2 deletions model/docs/getting_started.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ axs[0].plot(sim_data.Rt)
axs[0].set_ylabel("Rt")
# Infections plot
axs[1].plot(sim_data.sampled_observed_infections)
axs[1].plot(sim_data.observed_infections)
axs[1].set_ylabel("Infections")
fig.suptitle("Basic renewal model")
Expand All @@ -190,7 +190,7 @@ import jax
model1.run(
num_warmup=2000,
num_samples=1000,
observed_infections=sim_data.sampled_observed_infections,
data_observed_infections=sim_data.observed_infections,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
)
Expand Down
4 changes: 2 additions & 2 deletions model/docs/pyrenew_demo.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ fig, ax = plt.subplots(nrows=3, sharex=True)
ax[0].plot(x.latent_infections)
ax[0].set_ylim([1 / 5, 5])
ax[1].plot(x.latent_hosp_admissions)
ax[2].plot(x.sampled_observed_hosp_admissions, "o")
ax[2].plot(x.observed_hosp_admissions, "o")
for axis in ax[:-1]:
axis.set_yscale("log")
```
Expand All @@ -174,7 +174,7 @@ To fit the `hospmodel` to the simulated data, we call `hospmodel.run()`, an MCMC
hospmodel.run(
num_warmup=1000,
num_samples=1000,
observed_hosp_admissions=x.sampled_observed_hosp_admissions,
data_observed_hosp_admissions=x.observed_hosp_admissions,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
)
Expand Down
44 changes: 22 additions & 22 deletions model/src/pyrenew/model/admissionsmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,23 @@ class HospModelSample(NamedTuple):
The infected hospitalization rate. Defaults to None.
latent_hosp_admissions : ArrayLike | None, optional
The estimated latent hospitalizations. Defaults to None.
sampled_observed_hosp_admissions : ArrayLike | None, optional
observed_hosp_admissions : ArrayLike | None, optional
The sampled or observed hospital admissions. Defaults to None.
"""

Rt: float | None = None
latent_infections: ArrayLike | None = None
infection_hosp_rate: float | None = None
latent_hosp_admissions: ArrayLike | None = None
sampled_observed_hosp_admissions: ArrayLike | None = None
observed_hosp_admissions: ArrayLike | None = None

def __repr__(self):
return (
f"HospModelSample(Rt={self.Rt}, "
f"latent_infections={self.latent_infections}, "
f"infection_hosp_rate={self.infection_hosp_rate}, "
f"latent_hosp_admissions={self.latent_hosp_admissions}, "
f"sampled_observed_hosp_admissions={self.sampled_observed_hosp_admissions}"
f"observed_hosp_admissions={self.observed_hosp_admissions}"
)


Expand Down Expand Up @@ -162,7 +162,7 @@ def sample_latent_hosp_admissions(
def sample_admissions_process(
self,
observed_hosp_admissions_mean: ArrayLike,
observed_hosp_admissions: ArrayLike,
data_observed_hosp_admissions: ArrayLike,
name: str | None = None,
**kwargs,
) -> tuple:
Expand All @@ -188,15 +188,15 @@ def sample_admissions_process(

return self.hosp_admission_obs_process_rv.sample(
mu=observed_hosp_admissions_mean,
obs=observed_hosp_admissions,
obs=data_observed_hosp_admissions,
name=name,
**kwargs,
)

def sample(
self,
n_timepoints_to_simulate: int | None = None,
observed_hosp_admissions: ArrayLike | None = None,
data_observed_hosp_admissions: ArrayLike | None = None,
padding: int = 0,
**kwargs,
) -> HospModelSample:
Expand All @@ -207,7 +207,7 @@ def sample(
----------
n_timepoints_to_simulate : int, optional
Number of timepoints to sample (passed to the basic renewal model).
observed_hosp_admissions : ArrayLike, optional
data_observed_hosp_admissions : ArrayLike, optional
The observed hospitalization data (passed to the basic renewal
model). Defaults to None (simulation, rather than fit).
padding : int, optional
Expand All @@ -229,28 +229,28 @@ def sample(
"""
if (
n_timepoints_to_simulate is None
and observed_hosp_admissions is None
and data_observed_hosp_admissions is None
):
raise ValueError(
"Either n_timepoints_to_simulate or observed_hosp_admissions "
"Either n_timepoints_to_simulate or data_observed_hosp_admissions "
"must be passed."
)
elif (
n_timepoints_to_simulate is not None
and observed_hosp_admissions is not None
and data_observed_hosp_admissions is not None
):
raise ValueError(
"Cannot pass both n_timepoints_to_simulate and observed_hosp_admissions."
"Cannot pass both n_timepoints_to_simulate and data_observed_hosp_admissions."
)
elif n_timepoints_to_simulate is None:
n_timepoints = len(observed_hosp_admissions)
n_timepoints = len(data_observed_hosp_admissions)
else:
n_timepoints = n_timepoints_to_simulate

# Getting the initial quantities from the basic model
basic_model = self.basic_renewal.sample(
n_timepoints_to_simulate=n_timepoints,
observed_infections=None,
data_observed_infections=None,
padding=padding,
**kwargs,
)
Expand All @@ -266,33 +266,33 @@ def sample(
)
i0_size = len(latent_hosp_admissions) - n_timepoints
if self.hosp_admission_obs_process_rv is None:
sampled_observed_hosp_admissions = None
observed_hosp_admissions = None
else:
if observed_hosp_admissions is None:
if data_observed_hosp_admissions is None:
(
sampled_observed_hosp_admissions,
observed_hosp_admissions,
*_,
) = self.sample_admissions_process(
observed_hosp_admissions_mean=latent_hosp_admissions,
observed_hosp_admissions=observed_hosp_admissions,
data_observed_hosp_admissions=data_observed_hosp_admissions,
**kwargs,
)
else:
observed_hosp_admissions = au.pad_x_to_match_y(
observed_hosp_admissions,
data_observed_hosp_admissions = au.pad_x_to_match_y(
data_observed_hosp_admissions,
latent_hosp_admissions,
jnp.nan,
pad_direction="start",
)

(
sampled_observed_hosp_admissions,
observed_hosp_admissions,
*_,
) = self.sample_admissions_process(
observed_hosp_admissions_mean=latent_hosp_admissions[
i0_size + padding :
],
observed_hosp_admissions=observed_hosp_admissions[
data_observed_hosp_admissions=data_observed_hosp_admissions[
i0_size + padding :
],
**kwargs,
Expand All @@ -303,5 +303,5 @@ def sample(
latent_infections=basic_model.latent_infections,
infection_hosp_rate=infection_hosp_rate,
latent_hosp_admissions=latent_hosp_admissions,
sampled_observed_hosp_admissions=sampled_observed_hosp_admissions,
observed_hosp_admissions=observed_hosp_admissions,
)
53 changes: 29 additions & 24 deletions model/src/pyrenew/model/rtinfectionsrenewalmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,19 @@ class RtInfectionsRenewalSample(NamedTuple):
The reproduction number over time. Defaults to None.
latent_infections : ArrayLike | None, optional
The estimated latent infections. Defaults to None.
sampled_observed_infections : ArrayLike | None, optional
observed_infections : ArrayLike | None, optional
The sampled infections. Defaults to None.
"""

Rt: float | None = None
latent_infections: ArrayLike | None = None
sampled_observed_infections: ArrayLike | None = None
observed_infections: ArrayLike | None = None

def __repr__(self):
return (
f"RtInfectionsRenewalSample(Rt={self.Rt}, "
f"latent_infections={self.latent_infections}, "
f"sampled_observed_infections={self.sampled_observed_infections})"
f"observed_infections={self.observed_infections})"
)


Expand Down Expand Up @@ -216,7 +216,7 @@ def sample_infections_latent(
def sample_infection_obs_process(
self,
observed_infections_mean: ArrayLike,
observed_infections: ArrayLike | None = None,
data_observed_infections: ArrayLike | None = None,
name: str | None = None,
**kwargs,
) -> tuple:
Expand All @@ -229,7 +229,7 @@ def sample_infection_obs_process(
----------
observed_infections_mean : ArrayLike
The mean of the observed infections distribution.
observed_infections : ArrayLike | None, optional
data_observed_infections : ArrayLike | None, optional
The observed infection values, if any, for inference. Defaults to None.
name : str | None, optional
Name of the random variable passed to the RandomVariable. Defaults to None.
Expand All @@ -243,15 +243,15 @@ def sample_infection_obs_process(
"""
return self.infection_obs_process_rv.sample(
mu=observed_infections_mean,
obs=observed_infections,
obs=data_observed_infections,
name=name,
**kwargs,
)

def sample(
self,
n_timepoints_to_simulate: int | None = None,
observed_infections: ArrayLike | None = None,
data_observed_infections: ArrayLike | None = None,
padding: int = 0,
**kwargs,
) -> RtInfectionsRenewalSample:
Expand All @@ -262,7 +262,7 @@ def sample(
----------
n_timepoints_to_simulate : int, optional
Number of timepoints to sample.
observed_infections : ArrayLike | None, optional
data_observed_infections : ArrayLike | None, optional
Observed infections. Defaults to None.
padding : int, optional
Number of padding timepoints to add to the beginning of the
Expand All @@ -273,27 +273,30 @@ def sample(
Notes
-----
Either `observed_admissions` or `n_timepoints_to_simulate` must be specified, not both.
Either `data_observed_infections` or `n_timepoints_to_simulate` must be specified, not both.
Returns
-------
RtInfectionsRenewalSample
"""

if n_timepoints_to_simulate is None and observed_infections is None:
if (
n_timepoints_to_simulate is None
and data_observed_infections is None
):
raise ValueError(
"Either n_timepoints_to_simulate or observed_infections "
"Either n_timepoints_to_simulate or data_observed_infections "
"must be passed."
)
elif (
n_timepoints_to_simulate is not None
and observed_infections is not None
and data_observed_infections is not None
):
raise ValueError(
"Cannot pass both n_timepoints_to_simulate and observed_infections."
"Cannot pass both n_timepoints_to_simulate and data_observed_infections."
)
elif n_timepoints_to_simulate is None:
n_timepoints = len(observed_infections)
n_timepoints = len(data_observed_infections)
else:
n_timepoints = n_timepoints_to_simulate
# Sampling from Rt (possibly with a given Rt, depending on
Expand All @@ -317,36 +320,38 @@ def sample(
**kwargs,
)

if observed_infections is None:
if data_observed_infections is None:
(
sampled_observed_infections,
observed_infections,
*_,
) = self.sample_infection_obs_process(
observed_infections_mean=latent_infections,
observed_infections=observed_infections,
data_observed_infections=data_observed_infections,
**kwargs,
)
else:
observed_infections = au.pad_x_to_match_y(
observed_infections,
data_observed_infections = au.pad_x_to_match_y(
data_observed_infections,
latent_infections,
jnp.nan,
pad_direction="start",
)

(
sampled_observed_infections,
observed_infections,
*_,
) = self.sample_infection_obs_process(
observed_infections_mean=latent_infections[
I0_size + padding :
],
observed_infections=observed_infections[I0_size + padding :],
data_observed_infections=data_observed_infections[
I0_size + padding :
],
**kwargs,
)

sampled_observed_infections = au.pad_x_to_match_y(
sampled_observed_infections,
observed_infections = au.pad_x_to_match_y(
observed_infections,
latent_infections,
jnp.nan,
pad_direction="start",
Expand All @@ -358,5 +363,5 @@ def sample(
return RtInfectionsRenewalSample(
Rt=Rt,
latent_infections=latent_infections,
sampled_observed_infections=sampled_observed_infections,
observed_infections=observed_infections,
)
Loading

0 comments on commit 657a206

Please sign in to comment.