diff --git a/model/docs/example-with-datasets.qmd b/model/docs/example-with-datasets.qmd index e8706469..b59a92e0 100644 --- a/model/docs/example-with-datasets.qmd +++ b/model/docs/example-with-datasets.qmd @@ -126,7 +126,7 @@ from pyrenew import latent, deterministic, metaclass import jax.numpy as jnp import numpyro.distributions as dist -inf_hosp_int = deterministic.DeterministicPMF(inf_hosp_int) +inf_hosp_int = deterministic.DeterministicPMF(inf_hosp_int, name="inf_hosp_int") hosp_rate = metaclass.DistributionalRV( dist=dist.LogNormal(jnp.log(0.05), 0.1), @@ -153,7 +153,7 @@ I0 = metaclass.DistributionalRV( ) # Generation interval and Rt -gen_int = deterministic.DeterministicPMF(gen_int) +gen_int = deterministic.DeterministicPMF(gen_int, name="gen_int") rtproc = process.RtRandomWalkProcess( Rt_rw_dist=dist.Normal(0, 0.1) ) diff --git a/model/docs/extending_pyrenew.qmd b/model/docs/extending_pyrenew.qmd index 2076d054..02919f0a 100644 --- a/model/docs/extending_pyrenew.qmd +++ b/model/docs/extending_pyrenew.qmd @@ -36,8 +36,8 @@ The following code-chunk defines the model components. Notice that for both the ```{python} #| label: model-components -gen_int = DeterministicPMF(jnp.array([0.25, 0.5, 0.15, 0.1])) -feedback_strength = DeterministicVariable(0.05) +gen_int = DeterministicPMF(jnp.array([0.25, 0.5, 0.15, 0.1]), name="gen_int") +feedback_strength = DeterministicVariable(0.05, name="feedback_strength") I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") diff --git a/model/docs/getting-started.qmd b/model/docs/getting-started.qmd index ece4ee8a..3b3bbe51 100644 --- a/model/docs/getting-started.qmd +++ b/model/docs/getting-started.qmd @@ -84,7 +84,7 @@ To initialize these five components within the renewal modeling framework, we es ```{python} #| label: creating-elements # (1) The generation interval (deterministic) -gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25])) +gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int") # (2) Initial infections (inferred with a prior) I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") diff --git a/model/docs/pyrenew_demo.qmd b/model/docs/pyrenew_demo.qmd index a64c5fcb..defec1cb 100644 --- a/model/docs/pyrenew_demo.qmd +++ b/model/docs/pyrenew_demo.qmd @@ -86,7 +86,7 @@ To initialize the model, we first define initial conditions, including: # Initializing model components: # 1) A deterministic generation time -gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25])) +gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int") # 2) Initial infections I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") @@ -99,6 +99,7 @@ latent_infections = Infections() # First, define a deterministic infection to hosp pmf inf_hosp_int = DeterministicPMF( jnp.array([0, 0, 0,0,0,0,0,0,0,0,0,0,0, 0.25, 0.5, 0.1, 0.1, 0.05]), + name="inf_hosp_int" ) latent_admissions = HospitalAdmissions( diff --git a/model/src/pyrenew/deterministic/deterministic.py b/model/src/pyrenew/deterministic/deterministic.py index 1e0bcc04..cd602212 100644 --- a/model/src/pyrenew/deterministic/deterministic.py +++ b/model/src/pyrenew/deterministic/deterministic.py @@ -4,6 +4,7 @@ from __future__ import annotations import jax.numpy as jnp +import numpyro as npro from jax.typing import ArrayLike from pyrenew.metaclass import RandomVariable @@ -17,7 +18,7 @@ class DeterministicVariable(RandomVariable): def __init__( self, vars: ArrayLike, - label: str = "a_random_variable", + name: str, ) -> None: """Default constructor @@ -25,8 +26,8 @@ def __init__( ---------- vars : ArrayLike A tuple with arraylike objects. - label : str, optional - A label to assign to the process. Defaults to "a_random_variable" + name : str, optional + A name to assign to the process. Returns ------- @@ -35,7 +36,7 @@ def __init__( self.validate(vars) self.vars = jnp.atleast_1d(vars) - self.label = label + self.name = name return None @@ -65,6 +66,7 @@ def validate(vars: ArrayLike) -> None: def sample( self, + record=True, **kwargs, ) -> tuple: """ @@ -72,6 +74,8 @@ def sample( Parameters ---------- + record : bool, optional + Whether to record the value of the deterministic RandomVariable. Defaults to True. **kwargs : dict, optional Additional keyword arguments passed through to internal sample calls, should there be any. @@ -81,5 +85,6 @@ def sample( tuple Containing the stored values during construction. """ - + if record: + npro.deterministic(self.name, self.vars) return (self.vars,) diff --git a/model/src/pyrenew/deterministic/deterministicpmf.py b/model/src/pyrenew/deterministic/deterministicpmf.py index c31cfaa7..0d62208b 100644 --- a/model/src/pyrenew/deterministic/deterministicpmf.py +++ b/model/src/pyrenew/deterministic/deterministicpmf.py @@ -16,8 +16,8 @@ class DeterministicPMF(RandomVariable): def __init__( self, vars: ArrayLike, - label: str = "a_random_variable", - tol: float = 1e-20, + name: str, + tol: float = 1e-5, ) -> None: """ Default constructor @@ -31,11 +31,11 @@ def __init__( ---------- vars : tuple A tuple with arraylike objects. - label : str, optional - A label to assign to the process. Defaults to "a_random_variable" + name : str + A name to assign to the process. tol : float, optional Passed to pyrenew.distutil.validate_discrete_dist_vector. Defaults - to 1e-20. + to 1e-5. Returns ------- @@ -46,7 +46,7 @@ def __init__( tol=tol, ) - self.basevar = DeterministicVariable(vars, label) + self.basevar = DeterministicVariable(vars, name) return None diff --git a/model/src/pyrenew/distutil.py b/model/src/pyrenew/distutil.py index 30f042c8..206d3e80 100755 --- a/model/src/pyrenew/distutil.py +++ b/model/src/pyrenew/distutil.py @@ -15,7 +15,7 @@ def validate_discrete_dist_vector( - discrete_dist: ArrayLike, tol: float = 1e-20 + discrete_dist: ArrayLike, tol: float = 1e-5 ) -> ArrayLike: """ Validate that a vector represents a discrete @@ -30,7 +30,7 @@ def validate_discrete_dist_vector( must sum to 1 within the specified tolerance. tol : float, optional The tolerance within which the sum of the distribution must - be 1. Defaults to 1e-20. + be 1. Defaults to 1e-5. Returns ------- diff --git a/model/src/pyrenew/latent/hospitaladmissions.py b/model/src/pyrenew/latent/hospitaladmissions.py index 717668a2..cba1b0f7 100644 --- a/model/src/pyrenew/latent/hospitaladmissions.py +++ b/model/src/pyrenew/latent/hospitaladmissions.py @@ -94,9 +94,11 @@ def __init__( """ if weekday_effect_dist is None: - weekday_effect_dist = DeterministicVariable(1) + weekday_effect_dist = DeterministicVariable(1, "weekday_effect") if hosp_report_prob_dist is None: - hosp_report_prob_dist = DeterministicVariable(1) + hosp_report_prob_dist = DeterministicVariable( + 1, "hosp_report_prob" + ) HospitalAdmissions.validate( infect_hosp_rate_dist, diff --git a/model/src/test/test_deterministic.py b/model/src/test/test_deterministic.py index b59dfdd1..bf54e48e 100644 --- a/model/src/test/test_deterministic.py +++ b/model/src/test/test_deterministic.py @@ -22,10 +22,11 @@ def test_deterministic(): [ 1, ] - ) + ), + name="var1", ) - var2 = DeterministicPMF(jnp.array([0.25, 0.25, 0.2, 0.3])) - var3 = DeterministicProcess(jnp.array([1, 2, 3, 4])) + var2 = DeterministicPMF(jnp.array([0.25, 0.25, 0.2, 0.3]), name="var2") + var3 = DeterministicProcess(jnp.array([1, 2, 3, 4]), name="var3") var4 = NullVariable() var5 = NullProcess() diff --git a/model/src/test/test_infectionsrtfeedback.py b/model/src/test/test_infectionsrtfeedback.py index ed1e9434..856f17f5 100644 --- a/model/src/test/test_infectionsrtfeedback.py +++ b/model/src/test/test_infectionsrtfeedback.py @@ -68,8 +68,10 @@ def test_infectionsrtfeedback(): # By doing the infection feedback strength 0, Rt = Rt_adjusted # So infection should be equal in both - inf_feed_strength = DeterministicVariable(jnp.zeros_like(Rt)) - inf_feedback_pmf = DeterministicPMF(gen_int) + inf_feed_strength = DeterministicVariable( + jnp.zeros_like(Rt), name="inf_feed_strength" + ) + inf_feedback_pmf = DeterministicPMF(gen_int, name="inf_feedback_pmf") # Test the InfectionsWithFeedback class InfectionsWithFeedback = latent.InfectionsWithFeedback( @@ -107,8 +109,10 @@ def test_infectionsrtfeedback_feedback(): I0 = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]) gen_int = jnp.array([0.4, 0.25, 0.25, 0.1, 0.0, 0.0, 0.0]) - inf_feed_strength = DeterministicVariable(jnp.repeat(0.5, len(Rt))) - inf_feedback_pmf = DeterministicPMF(gen_int) + inf_feed_strength = DeterministicVariable( + jnp.repeat(0.5, len(Rt)), name="inf_feed_strength" + ) + inf_feedback_pmf = DeterministicPMF(gen_int, name="inf_feedback_pmf") # Test the InfectionsWithFeedback class InfectionsWithFeedback = latent.InfectionsWithFeedback( diff --git a/model/src/test/test_latent_admissions.py b/model/src/test/test_latent_admissions.py index 144654aa..60a0b813 100644 --- a/model/src/test/test_latent_admissions.py +++ b/model/src/test/test_latent_admissions.py @@ -56,6 +56,7 @@ def test_admissions_sample(): 0.05, ] ), + name="inf_hosp", ) hosp1 = HospitalAdmissions( diff --git a/model/src/test/test_model_basic_renewal.py b/model/src/test/test_model_basic_renewal.py index 0b8087a1..3210bf03 100644 --- a/model/src/test/test_model_basic_renewal.py +++ b/model/src/test/test_model_basic_renewal.py @@ -23,7 +23,9 @@ def test_model_basicrenewal_no_obs_model(): from the perspective of the infections. It returns expected, not sampled. """ - gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25])) + gen_int = DeterministicPMF( + jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" + ) with pytest.raises(ValueError): I0 = DistributionalRV(dist=1, name="I0") @@ -88,7 +90,9 @@ def test_model_basicrenewal_with_obs_model(): from the perspective of the infections. It returns sampled, not expected. """ - gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25])) + gen_int = DeterministicPMF( + jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" + ) I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") @@ -152,7 +156,9 @@ def test_model_basicrenewal_plot() -> plt.Figure: This will skip validating the figure and save the new figure in the `src/test/baseline` folder. """ - gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25])) + gen_int = DeterministicPMF( + jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" + ) I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") @@ -190,7 +196,9 @@ def test_model_basicrenewal_plot() -> plt.Figure: def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08 - gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25])) + gen_int = DeterministicPMF( + jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" + ) I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") diff --git a/model/src/test/test_model_hospitalizations.py b/model/src/test/test_model_hospitalizations.py index e0314639..438cff46 100644 --- a/model/src/test/test_model_hospitalizations.py +++ b/model/src/test/test_model_hospitalizations.py @@ -41,7 +41,9 @@ def test_model_hosp_no_obs_model(): Hospitalization model runs """ - gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25])) + gen_int = DeterministicPMF( + jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" + ) I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") @@ -70,6 +72,7 @@ def test_model_hosp_no_obs_model(): 0.05, ] ), + name="inf_hosp", ) latent_admissions = HospitalAdmissions( @@ -137,7 +140,9 @@ def test_model_hosp_with_obs_model(): Checks that the random Hospitalization model runs """ - gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25])) + gen_int = DeterministicPMF( + jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" + ) I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") @@ -168,6 +173,7 @@ def test_model_hosp_with_obs_model(): 0.05, ], ), + name="inf_hosp", ) latent_admissions = HospitalAdmissions( @@ -216,7 +222,9 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): Checks that the random Hospitalization model runs """ - gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25])) + gen_int = DeterministicPMF( + jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" + ) I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") @@ -247,6 +255,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): 0.05, ], ), + name="inf_hosp", ) # Other random components @@ -306,7 +315,9 @@ def test_model_hosp_with_obs_model_weekday_phosp(): Checks that the random Hospitalization model runs """ - gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25])) + gen_int = DeterministicPMF( + jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" + ) I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0") @@ -337,6 +348,7 @@ def test_model_hosp_with_obs_model_weekday_phosp(): 0.05, ], ), + name="inf_hosp", ) # Other random components @@ -345,7 +357,7 @@ def test_model_hosp_with_obs_model_weekday_phosp(): weekday = weekday / weekday.sum() weekday = weekday[:31] - weekday = DeterministicVariable(weekday) + weekday = DeterministicVariable(weekday, name="weekday") hosp_report_prob_dist = jnp.array([0.9, 0.8, 0.7, 0.7, 0.6, 0.4]) hosp_report_prob_dist = jnp.tile(hosp_report_prob_dist, 10) @@ -353,7 +365,9 @@ def test_model_hosp_with_obs_model_weekday_phosp(): hosp_report_prob_dist = hosp_report_prob_dist[:31] - hosp_report_prob_dist = DeterministicVariable(vars=hosp_report_prob_dist) + hosp_report_prob_dist = DeterministicVariable( + vars=hosp_report_prob_dist, name="hosp_report_prob_dist" + ) latent_admissions = HospitalAdmissions( infection_to_admission_interval=inf_hosp,