diff --git a/model/src/pyrenew/latent/__init__.py b/model/src/pyrenew/latent/__init__.py index ecbad7a4..ccb85f1f 100644 --- a/model/src/pyrenew/latent/__init__.py +++ b/model/src/pyrenew/latent/__init__.py @@ -8,9 +8,9 @@ ) from pyrenew.latent.i0 import Infections0 from pyrenew.latent.infection_functions import ( + compute_infections_from_rt, + compute_infections_from_rt_with_feedback, logistic_susceptibility_adjustment, - sample_infections_rt, - sample_infections_with_feedback, ) from pyrenew.latent.infections import Infections @@ -19,7 +19,7 @@ "InfectHospRate", "Infections", "logistic_susceptibility_adjustment", - "sample_infections_rt", - "sample_infections_with_feedback", + "compute_infections_from_rt", + "compute_infections_from_rt_with_feedback", "Infections0", ] diff --git a/model/src/pyrenew/latent/infection_functions.py b/model/src/pyrenew/latent/infection_functions.py index 54523806..7ce8a931 100755 --- a/model/src/pyrenew/latent/infection_functions.py +++ b/model/src/pyrenew/latent/infection_functions.py @@ -9,11 +9,13 @@ from pyrenew.convolve import new_convolve_scanner, new_double_scanner -def sample_infections_rt( - I0: ArrayLike, Rt: ArrayLike, reversed_generation_interval_pmf: ArrayLike +def compute_infections_from_rt( + I0: ArrayLike, + Rt: ArrayLike, + reversed_generation_interval_pmf: ArrayLike, ) -> ArrayLike: """ - Sample infections according to a + Generate infections according to a renewal process with a time-varying reproduction number R(t) @@ -84,7 +86,7 @@ def logistic_susceptibility_adjustment( return n_population * frac_susceptible * approx_frac_infected -def sample_infections_with_feedback( +def compute_infections_from_rt_with_feedback( I0: ArrayLike, Rt_raw: ArrayLike, infection_feedback_strength: ArrayLike, @@ -92,7 +94,7 @@ def sample_infections_with_feedback( infection_feedback_pmf: ArrayLike, ) -> tuple: """ - Sample infections according to + Generate infections according to a renewal process with infection feedback (generalizing Asher 2018: https://doi.org/10.1016/j.epidem.2017.02.009) diff --git a/model/src/pyrenew/latent/infections.py b/model/src/pyrenew/latent/infections.py index 30f2bc71..5edfbb99 100644 --- a/model/src/pyrenew/latent/infections.py +++ b/model/src/pyrenew/latent/infections.py @@ -107,7 +107,7 @@ def sample( n_lead = gen_int_rev.size - 1 I0_vec = jnp.hstack([jnp.zeros(n_lead), I0]) - all_infections = inf.sample_infections_rt( + all_infections = inf.compute_infections_from_rt( I0=I0_vec, Rt=Rt, reversed_generation_interval_pmf=gen_int_rev,