Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Record deterministic RandomVariables by default #148

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions model/docs/example-with-datasets.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ from pyrenew import latent, deterministic, metaclass
import jax.numpy as jnp
import numpyro.distributions as dist

inf_hosp_int = deterministic.DeterministicPMF(inf_hosp_int)
inf_hosp_int = deterministic.DeterministicPMF(inf_hosp_int, name="inf_hosp_int")

hosp_rate = metaclass.DistributionalRV(
dist=dist.LogNormal(jnp.log(0.05), 0.1),
Expand All @@ -153,7 +153,7 @@ I0 = metaclass.DistributionalRV(
)

# Generation interval and Rt
gen_int = deterministic.DeterministicPMF(gen_int)
gen_int = deterministic.DeterministicPMF(gen_int, name="gen_int")
rtproc = process.RtRandomWalkProcess(
Rt_rw_dist=dist.Normal(0, 0.1)
)
Expand Down
4 changes: 2 additions & 2 deletions model/docs/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ The following code-chunk defines the model components. Notice that for both the

```{python}
#| label: model-components
gen_int = DeterministicPMF(jnp.array([0.25, 0.5, 0.15, 0.1]))
feedback_strength = DeterministicVariable(0.05)
gen_int = DeterministicPMF(jnp.array([0.25, 0.5, 0.15, 0.1]), name="gen_int")
feedback_strength = DeterministicVariable(0.05, name="feedback_strength")

I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0")

Expand Down
2 changes: 1 addition & 1 deletion model/docs/getting-started.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ To initialize these five components within the renewal modeling framework, we es
```{python}
#| label: creating-elements
# (1) The generation interval (deterministic)
gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25]))
gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int")

# (2) Initial infections (inferred with a prior)
I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0")
Expand Down
3 changes: 2 additions & 1 deletion model/docs/pyrenew_demo.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ To initialize the model, we first define initial conditions, including:
# Initializing model components:

# 1) A deterministic generation time
gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25]))
gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int")

# 2) Initial infections
I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0")
Expand All @@ -99,6 +99,7 @@ latent_infections = Infections()
# First, define a deterministic infection to hosp pmf
inf_hosp_int = DeterministicPMF(
jnp.array([0, 0, 0,0,0,0,0,0,0,0,0,0,0, 0.25, 0.5, 0.1, 0.1, 0.05]),
name="inf_hosp_int"
)

latent_admissions = HospitalAdmissions(
Expand Down
15 changes: 10 additions & 5 deletions model/src/pyrenew/deterministic/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

import jax.numpy as jnp
import numpyro as npro
from jax.typing import ArrayLike
from pyrenew.metaclass import RandomVariable

Expand All @@ -17,16 +18,16 @@ class DeterministicVariable(RandomVariable):
def __init__(
self,
vars: ArrayLike,
label: str = "a_random_variable",
name: str,
) -> None:
"""Default constructor

Parameters
----------
vars : ArrayLike
A tuple with arraylike objects.
label : str, optional
A label to assign to the process. Defaults to "a_random_variable"
name : str, optional
A name to assign to the process.

Returns
-------
Expand All @@ -35,7 +36,7 @@ def __init__(

self.validate(vars)
self.vars = jnp.atleast_1d(vars)
self.label = label
self.name = name

return None

Expand Down Expand Up @@ -65,13 +66,16 @@ def validate(vars: ArrayLike) -> None:

def sample(
self,
record=True,
**kwargs,
) -> tuple:
"""
Retrieve the value of the deterministic Rv

Parameters
----------
record : bool, optional
Whether to record the value of the deterministic RandomVariable. Defaults to True.
**kwargs : dict, optional
Additional keyword arguments passed through to internal
sample calls, should there be any.
Expand All @@ -81,5 +85,6 @@ def sample(
tuple
Containing the stored values during construction.
"""

if record:
npro.deterministic(self.name, self.vars)
return (self.vars,)
12 changes: 6 additions & 6 deletions model/src/pyrenew/deterministic/deterministicpmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ class DeterministicPMF(RandomVariable):
def __init__(
self,
vars: ArrayLike,
label: str = "a_random_variable",
tol: float = 1e-20,
name: str,
tol: float = 1e-5,
) -> None:
"""
Default constructor
Expand All @@ -31,11 +31,11 @@ def __init__(
----------
vars : tuple
A tuple with arraylike objects.
label : str, optional
A label to assign to the process. Defaults to "a_random_variable"
name : str
A name to assign to the process.
tol : float, optional
Passed to pyrenew.distutil.validate_discrete_dist_vector. Defaults
to 1e-20.
to 1e-5.

Returns
-------
Expand All @@ -46,7 +46,7 @@ def __init__(
tol=tol,
)

self.basevar = DeterministicVariable(vars, label)
self.basevar = DeterministicVariable(vars, name)

return None

Expand Down
4 changes: 2 additions & 2 deletions model/src/pyrenew/distutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


def validate_discrete_dist_vector(
discrete_dist: ArrayLike, tol: float = 1e-20
discrete_dist: ArrayLike, tol: float = 1e-5
) -> ArrayLike:
"""
Validate that a vector represents a discrete
Expand All @@ -30,7 +30,7 @@ def validate_discrete_dist_vector(
must sum to 1 within the specified tolerance.
tol : float, optional
The tolerance within which the sum of the distribution must
be 1. Defaults to 1e-20.
be 1. Defaults to 1e-5.

Returns
-------
Expand Down
6 changes: 4 additions & 2 deletions model/src/pyrenew/latent/hospitaladmissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,11 @@ def __init__(
"""

if weekday_effect_dist is None:
weekday_effect_dist = DeterministicVariable(1)
weekday_effect_dist = DeterministicVariable(1, "weekday_effect")
if hosp_report_prob_dist is None:
hosp_report_prob_dist = DeterministicVariable(1)
hosp_report_prob_dist = DeterministicVariable(
1, "hosp_report_prob"
)

HospitalAdmissions.validate(
infect_hosp_rate_dist,
Expand Down
7 changes: 4 additions & 3 deletions model/src/test/test_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ def test_deterministic():
[
1,
]
)
),
name="var1",
)
var2 = DeterministicPMF(jnp.array([0.25, 0.25, 0.2, 0.3]))
var3 = DeterministicProcess(jnp.array([1, 2, 3, 4]))
var2 = DeterministicPMF(jnp.array([0.25, 0.25, 0.2, 0.3]), name="var2")
var3 = DeterministicProcess(jnp.array([1, 2, 3, 4]), name="var3")
var4 = NullVariable()
var5 = NullProcess()

Expand Down
12 changes: 8 additions & 4 deletions model/src/test/test_infectionsrtfeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,10 @@ def test_infectionsrtfeedback():

# By doing the infection feedback strength 0, Rt = Rt_adjusted
# So infection should be equal in both
inf_feed_strength = DeterministicVariable(jnp.zeros_like(Rt))
inf_feedback_pmf = DeterministicPMF(gen_int)
inf_feed_strength = DeterministicVariable(
jnp.zeros_like(Rt), name="inf_feed_strength"
)
inf_feedback_pmf = DeterministicPMF(gen_int, name="inf_feedback_pmf")

# Test the InfectionsWithFeedback class
InfectionsWithFeedback = latent.InfectionsWithFeedback(
Expand Down Expand Up @@ -107,8 +109,10 @@ def test_infectionsrtfeedback_feedback():
I0 = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0])
gen_int = jnp.array([0.4, 0.25, 0.25, 0.1, 0.0, 0.0, 0.0])

inf_feed_strength = DeterministicVariable(jnp.repeat(0.5, len(Rt)))
inf_feedback_pmf = DeterministicPMF(gen_int)
inf_feed_strength = DeterministicVariable(
jnp.repeat(0.5, len(Rt)), name="inf_feed_strength"
)
inf_feedback_pmf = DeterministicPMF(gen_int, name="inf_feedback_pmf")

# Test the InfectionsWithFeedback class
InfectionsWithFeedback = latent.InfectionsWithFeedback(
Expand Down
1 change: 1 addition & 0 deletions model/src/test/test_latent_admissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def test_admissions_sample():
0.05,
]
),
name="inf_hosp",
)

hosp1 = HospitalAdmissions(
Expand Down
16 changes: 12 additions & 4 deletions model/src/test/test_model_basic_renewal.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ def test_model_basicrenewal_no_obs_model():
from the perspective of the infections. It returns expected, not sampled.
"""

gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25]))
gen_int = DeterministicPMF(
jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int"
)

with pytest.raises(ValueError):
I0 = DistributionalRV(dist=1, name="I0")
Expand Down Expand Up @@ -88,7 +90,9 @@ def test_model_basicrenewal_with_obs_model():
from the perspective of the infections. It returns sampled, not expected.
"""

gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25]))
gen_int = DeterministicPMF(
jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int"
)

I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0")

Expand Down Expand Up @@ -152,7 +156,9 @@ def test_model_basicrenewal_plot() -> plt.Figure:
This will skip validating the figure and save the new figure in the
`src/test/baseline` folder.
"""
gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25]))
gen_int = DeterministicPMF(
jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int"
)

I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0")

Expand Down Expand Up @@ -190,7 +196,9 @@ def test_model_basicrenewal_plot() -> plt.Figure:


def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08
gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25]))
gen_int = DeterministicPMF(
jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int"
)

I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0")

Expand Down
26 changes: 20 additions & 6 deletions model/src/test/test_model_hospitalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def test_model_hosp_no_obs_model():
Hospitalization model runs
"""

gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25]))
gen_int = DeterministicPMF(
jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int"
)

I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0")

Expand Down Expand Up @@ -70,6 +72,7 @@ def test_model_hosp_no_obs_model():
0.05,
]
),
name="inf_hosp",
)

latent_admissions = HospitalAdmissions(
Expand Down Expand Up @@ -137,7 +140,9 @@ def test_model_hosp_with_obs_model():
Checks that the random Hospitalization model runs
"""

gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25]))
gen_int = DeterministicPMF(
jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int"
)

I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0")

Expand Down Expand Up @@ -168,6 +173,7 @@ def test_model_hosp_with_obs_model():
0.05,
],
),
name="inf_hosp",
)

latent_admissions = HospitalAdmissions(
Expand Down Expand Up @@ -216,7 +222,9 @@ def test_model_hosp_with_obs_model_weekday_phosp_2():
Checks that the random Hospitalization model runs
"""

gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25]))
gen_int = DeterministicPMF(
jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int"
)

I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0")

Expand Down Expand Up @@ -247,6 +255,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2():
0.05,
],
),
name="inf_hosp",
)

# Other random components
Expand Down Expand Up @@ -306,7 +315,9 @@ def test_model_hosp_with_obs_model_weekday_phosp():
Checks that the random Hospitalization model runs
"""

gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25]))
gen_int = DeterministicPMF(
jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int"
)

I0 = DistributionalRV(dist=dist.LogNormal(0, 1), name="I0")

Expand Down Expand Up @@ -337,6 +348,7 @@ def test_model_hosp_with_obs_model_weekday_phosp():
0.05,
],
),
name="inf_hosp",
)

# Other random components
Expand All @@ -345,15 +357,17 @@ def test_model_hosp_with_obs_model_weekday_phosp():
weekday = weekday / weekday.sum()
weekday = weekday[:31]

weekday = DeterministicVariable(weekday)
weekday = DeterministicVariable(weekday, name="weekday")

hosp_report_prob_dist = jnp.array([0.9, 0.8, 0.7, 0.7, 0.6, 0.4])
hosp_report_prob_dist = jnp.tile(hosp_report_prob_dist, 10)
hosp_report_prob_dist = hosp_report_prob_dist / hosp_report_prob_dist.sum()

hosp_report_prob_dist = hosp_report_prob_dist[:31]

hosp_report_prob_dist = DeterministicVariable(vars=hosp_report_prob_dist)
hosp_report_prob_dist = DeterministicVariable(
vars=hosp_report_prob_dist, name="hosp_report_prob_dist"
)

latent_admissions = HospitalAdmissions(
infection_to_admission_interval=inf_hosp,
Expand Down
Loading