Skip to content

Commit

Permalink
Clean up Rt infections renewal model (#204)
Browse files Browse the repository at this point in the history
* clean up rt infections renewal model

* get rid of image test

* fix test
  • Loading branch information
damonbayer authored Jun 20, 2024
1 parent 63ed209 commit 71228e5
Show file tree
Hide file tree
Showing 9 changed files with 46 additions and 131 deletions.
4 changes: 0 additions & 4 deletions model/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,6 @@ docs_clean:
rm -rf docs/*_files/
rm -f $(MD_FILES) $(IPYNB_FILES) $(PY_FILES)

test_images:
echo "Generating reference images for tests"
poetry run pytest --mpl-generate-path=src/test/baseline

image-build: Dockerfile
$(CONTAINER) build -t pyrenew:latest .

Expand Down
4 changes: 2 additions & 2 deletions model/src/pyrenew/latent/hospitaladmissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,13 @@ def sample(
infection_hosp_rate_t = infection_hosp_rate * latent_infections

(
infection_to_admission_interval_rv,
infection_to_admission_interval,
*_,
) = self.infection_to_admission_interval_rv.sample(**kwargs)

latent_hospital_admissions = jnp.convolve(
infection_hosp_rate_t,
infection_to_admission_interval_rv,
infection_to_admission_interval,
mode="full",
)[: infection_hosp_rate_t.shape[0]]

Expand Down
13 changes: 4 additions & 9 deletions model/src/pyrenew/latent/infections.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import NamedTuple

import jax.numpy as jnp
import numpyro as npro
import pyrenew.latent.infection_functions as inf
from jax.typing import ArrayLike
from pyrenew.metaclass import RandomVariable
Expand All @@ -18,11 +17,11 @@ class InfectionsSample(NamedTuple):
Attributes
----------
infections : ArrayLike | None, optional
post_seed_infections : ArrayLike | None, optional
The estimated latent infections. Defaults to None.
"""

infections: ArrayLike | None = None
post_seed_infections: ArrayLike | None = None

def __repr__(self):
return f"InfectionsSample(infections={self.infections})"
Expand Down Expand Up @@ -114,14 +113,10 @@ def sample(
gen_int_rev = jnp.flip(gen_int)
recent_I0 = I0[-gen_int_rev.size :]

all_infections = inf.compute_infections_from_rt(
post_seed_infections = inf.compute_infections_from_rt(
I0=recent_I0,
Rt=Rt,
reversed_generation_interval_pmf=gen_int_rev,
)

all_infections = jnp.hstack([I0, all_infections])

npro.deterministic(self.infections_mean_varname, all_infections)

return InfectionsSample(all_infections)
return InfectionsSample(post_seed_infections)
14 changes: 8 additions & 6 deletions model/src/pyrenew/latent/infectionswithfeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@ class InfectionsRtFeedbackSample(NamedTuple):
Attributes
----------
infections : ArrayLike | None, optional
post_seed_infections : ArrayLike | None, optional
The estimated latent infections. Defaults to None.
rt : ArrayLike | None, optional
The adjusted reproduction number. Defaults to None.
"""

infections: ArrayLike | None = None
post_seed_infections: ArrayLike | None = None
rt: ArrayLike | None = None

def __repr__(self):
return f"InfectionsSample(infections={self.infections}, rt={self.rt})"
return f"InfectionsSample(post_seed_infections={self.post_seed_infections}, rt={self.rt})"


class InfectionsWithFeedback(RandomVariable):
Expand Down Expand Up @@ -187,7 +187,10 @@ def sample(

inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf)

all_infections, Rt_adj = inf.compute_infections_from_rt_with_feedback(
(
post_seed_infections,
Rt_adj,
) = inf.compute_infections_from_rt_with_feedback(
I0=I0,
Rt_raw=Rt,
infection_feedback_strength=inf_feedback_strength,
Expand All @@ -196,11 +199,10 @@ def sample(
)

# Appending initial infections to the infections
all_infections = jnp.hstack([I0, all_infections])

npro.deterministic("Rt_adjusted", Rt_adj)

return InfectionsRtFeedbackSample(
infections=all_infections,
post_seed_infections=post_seed_infections,
rt=Rt_adj,
)
60 changes: 23 additions & 37 deletions model/src/pyrenew/model/rtinfectionsrenewalmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import NamedTuple

import jax.numpy as jnp
import numpyro as npro
import pyrenew.arrayutils as au
from numpy.typing import ArrayLike
from pyrenew.deterministic import NullObservation
Expand All @@ -19,21 +20,22 @@ class RtInfectionsRenewalSample(NamedTuple):
Attributes
----------
Rt : float | None, optional
Rt : ArrayLike | None, optional
The reproduction number over time. Defaults to None.
latent_infections : ArrayLike | None, optional
The estimated latent infections. Defaults to None.
observed_infections : ArrayLike | None, optional
The sampled infections. Defaults to None.
"""

Rt: float | None = None
Rt: ArrayLike | None = None
latent_infections: ArrayLike | None = None
observed_infections: ArrayLike | None = None

def __repr__(self):
return (
f"RtInfectionsRenewalSample(Rt={self.Rt}, "
f"RtInfectionsRenewalSample("
f"Rt={self.Rt}, "
f"latent_infections={self.latent_infections}, "
f"observed_infections={self.observed_infections})"
)
Expand Down Expand Up @@ -311,57 +313,41 @@ def sample(

# Sampling initial infections
I0, *_ = self.sample_I0(**kwargs)
I0_size = I0.size
# Sampling from the latent process
latent_infections, *_ = self.sample_infections_latent(
post_seed_latent_infections, *_ = self.sample_infections_latent(
Rt=Rt,
gen_int=gen_int,
I0=I0,
**kwargs,
)

if data_observed_infections is None:
(
observed_infections,
*_,
) = self.sample_infection_obs_process(
observed_infections_mean=latent_infections,
data_observed_infections=data_observed_infections,
**kwargs,
)
else:
data_observed_infections = au.pad_x_to_match_y(
data_observed_infections,
latent_infections,
jnp.nan,
pad_direction="start",
)
if data_observed_infections is not None:
data_observed_infections = data_observed_infections[padding:]

(
observed_infections,
*_,
) = self.sample_infection_obs_process(
observed_infections_mean=latent_infections[
I0_size + padding :
],
data_observed_infections=data_observed_infections[
I0_size + padding :
],
**kwargs,
)
observed_infections, *_ = self.sample_infection_obs_process(
observed_infections_mean=post_seed_latent_infections[padding:],
data_observed_infections=data_observed_infections,
**kwargs,
)

all_latent_infections = jnp.hstack([I0, post_seed_latent_infections])
npro.deterministic("latent_infections", all_latent_infections)

observed_infections = au.pad_x_to_match_y(
observed_infections,
latent_infections,
all_latent_infections,
jnp.nan,
pad_direction="start",
)

Rt = au.pad_x_to_match_y(
Rt, latent_infections, jnp.nan, pad_direction="start"
Rt,
all_latent_infections,
jnp.nan,
pad_direction="start",
)

return RtInfectionsRenewalSample(
Rt=Rt,
latent_infections=latent_infections,
latent_infections=all_latent_infections,
observed_infections=observed_infections,
)
Binary file not shown.
12 changes: 8 additions & 4 deletions model/src/test/test_infectionsrtfeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _infection_w_feedback_alt(
I_vec[t : t + len_gen], np.flip(gen_int)
)

return {"infections": I_vec, "rt": Rt_adj}
return {"post_seed_infections": I_vec[I0.size :], "rt": Rt_adj}


def test_infectionsrtfeedback():
Expand Down Expand Up @@ -94,7 +94,7 @@ def test_infectionsrtfeedback():
I0=I0,
)

assert_array_equal(samp1.infections, samp2.infections)
assert_array_equal(samp1.post_seed_infections, samp2.post_seed_infections)
assert_array_equal(samp1.rt, Rt)

return None
Expand Down Expand Up @@ -143,8 +143,12 @@ def test_infectionsrtfeedback_feedback():
inf_feedback_pmf=inf_feedback_pmf.sample()[0],
)

assert not jnp.array_equal(samp1.infections, samp2.infections)
assert_array_almost_equal(samp1.infections, res["infections"])
assert not jnp.array_equal(
samp1.post_seed_infections, samp2.post_seed_infections
)
assert_array_almost_equal(
samp1.post_seed_infections, res["post_seed_infections"]
)
assert_array_almost_equal(samp1.rt, res["rt"])

return None
2 changes: 1 addition & 1 deletion model/src/test/test_latent_infections.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_infections_as_deterministic():
inf_sampled2 = inf1.sample(**obs)

testing.assert_array_equal(
inf_sampled1.infections, inf_sampled2.infections
inf_sampled1.post_seed_infections, inf_sampled2.post_seed_infections
)

# Check that Initial infections vector must be at least as long as the generation interval.
Expand Down
68 changes: 0 additions & 68 deletions model/src/test/test_model_basic_renewal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import numpy as np
import numpyro as npro
import numpyro.distributions as dist
Expand Down Expand Up @@ -234,73 +233,6 @@ def test_model_basicrenewal_with_obs_model():
assert inf_mean.to_numpy().shape[0] == 500


@pytest.mark.mpl_image_compare
def test_model_basicrenewal_plot() -> plt.Figure:
"""
Check that the posterior sample looks the same (reproducibility)
Returns
-------
plt.Figure
The figure object
Notes
-----
IMPORTANT: If this test fails, it may be that you need
to regenerate the figures. To do so, you can the test using the following
command:
poetry run pytest --mpl-generate-path=src/test/baseline
This will skip validating the figure and save the new figure in the
`src/test/baseline` folder.
"""
gen_int = DeterministicPMF(
jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int"
)

I0 = InfectionSeedingProcess(
"I0_seeding",
DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"),
SeedInfectionsZeroPad(n_timepoints=gen_int.size()),
)

latent_infections = Infections()

observed_infections = PoissonObservation()

rt = RtRandomWalkProcess(
Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0),
Rt_transform=t.ExpTransform().inv,
Rt_rw_dist=dist.Normal(0, 0.025),
)

model1 = RtInfectionsRenewalModel(
I0_rv=I0,
gen_int_rv=gen_int,
latent_infections_rv=latent_infections,
infection_obs_process_rv=observed_infections,
Rt_process_rv=rt,
)

# Sampling and fitting model 1 (with obs infections)
np.random.seed(2203)
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
model1_samp = model1.sample(n_timepoints_to_simulate=30)

model1.run(
num_warmup=500,
num_samples=500,
rng_key=jr.key(22),
data_observed_infections=model1_samp.observed_infections,
)

return model1.plot_posterior(
var="latent_infections",
obs_signal=model1_samp.observed_infections,
)


def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08
gen_int = DeterministicPMF(
jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int"
Expand Down

0 comments on commit 71228e5

Please sign in to comment.