From a7b6bf987e63c37bfb6254742eba9a5b0db5ae36 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Tue, 10 Sep 2024 10:25:54 -0500 Subject: [PATCH 1/2] reorder arguments and add new test --- pyrenew/convolve.py | 3 ++- pyrenew/latent/hospitaladmissions.py | 2 +- test/test_incidence_observed_with_delay.py | 20 +++++++++++++++++++- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/pyrenew/convolve.py b/pyrenew/convolve.py index 44d98e4c..55b2227f 100755 --- a/pyrenew/convolve.py +++ b/pyrenew/convolve.py @@ -11,6 +11,7 @@ :py:func:`jax.lax.scan` with an appropriate array to scan along. """ + from __future__ import annotations from typing import Callable @@ -166,9 +167,9 @@ def _new_scanner( def compute_delay_ascertained_incidence( - p_observed_given_incident: ArrayLike, latent_incidence: ArrayLike, delay_incidence_to_observation_pmf: ArrayLike, + p_observed_given_incident: ArrayLike = 1, ) -> ArrayLike: """ Computes incidences observed according diff --git a/pyrenew/latent/hospitaladmissions.py b/pyrenew/latent/hospitaladmissions.py index 57090528..1fcfc581 100644 --- a/pyrenew/latent/hospitaladmissions.py +++ b/pyrenew/latent/hospitaladmissions.py @@ -211,9 +211,9 @@ def sample( ) = self.infection_to_admission_interval_rv(**kwargs) latent_hospital_admissions = compute_delay_ascertained_incidence( - infection_hosp_rate.value, latent_infections.value, infection_to_admission_interval.value, + infection_hosp_rate.value, ) # Applying the day of the week effect. For this we need to: diff --git a/test/test_incidence_observed_with_delay.py b/test/test_incidence_observed_with_delay.py index 9bff64ad..e1a5145a 100644 --- a/test/test_incidence_observed_with_delay.py +++ b/test/test_incidence_observed_with_delay.py @@ -34,6 +34,12 @@ jnp.array([0.25, 0.5, 0.25]), jnp.array([2]), ], + [ + jnp.array([1.0]), + jnp.array([0, 2.0, 4.0]), + jnp.array([0.25, 0.5, 0.25]), + jnp.array([2]), + ], ], ) def test(obs_rate, latent_incidence, delay_interval, expected_output): @@ -42,8 +48,20 @@ def test(obs_rate, latent_incidence, delay_interval, expected_output): incidence observed with a delay """ result = compute_delay_ascertained_incidence( - obs_rate, latent_incidence, delay_interval, + obs_rate, ) assert_array_equal(result, expected_output) + + +def test_default_obs_rate(): + """ + Tests for helper function to compute + incidence observed with a delay + """ + result = compute_delay_ascertained_incidence( + jnp.array([1.0, 2.0, 3.0]), + jnp.array([1.0]), + ) + assert_array_equal(result, jnp.array([1.0, 2.0, 3.0])) From a0891b4a6a7df32b3bb5dbe4640620298d806251 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Tue, 10 Sep 2024 10:28:44 -0500 Subject: [PATCH 2/2] add test description --- test/test_incidence_observed_with_delay.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_incidence_observed_with_delay.py b/test/test_incidence_observed_with_delay.py index e1a5145a..bcb1ff66 100644 --- a/test/test_incidence_observed_with_delay.py +++ b/test/test_incidence_observed_with_delay.py @@ -57,8 +57,7 @@ def test(obs_rate, latent_incidence, delay_interval, expected_output): def test_default_obs_rate(): """ - Tests for helper function to compute - incidence observed with a delay + Compute incidence observed with a delay and default observation rate """ result = compute_delay_ascertained_incidence( jnp.array([1.0, 2.0, 3.0]),