Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First Pass Documentation (docstrings) For Latent Folder #73

Merged
merged 13 commits into from
Apr 15, 2024
12 changes: 8 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,14 @@ repos:
- id: isort
args: ['--profile', 'black',
'--line-length', '79']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.0
hooks:
- id: ruff
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.0
hooks:
- id: ruff
- repo: https://github.com/numpy/numpydoc
rev: v1.6.0
hooks:
- id: numpydoc-validation
#####
# Secrets
- repo: https://github.com/Yelp/detect-secrets
Expand Down
567 changes: 558 additions & 9 deletions model/poetry.lock

Large diffs are not rendered by default.

18 changes: 18 additions & 0 deletions model/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,26 @@ polars = "^0.20.13"
matplotlib = "^3.8.3"
pillow = "^10.3.0" # See #56 on CDCgov/multisignal-epi-inference
ipykernel = "^6.29.3"
numpydoc = "^1.6.0"
gvegayon marked this conversation as resolved.
Show resolved Hide resolved


[tool.numpydoc_validation]
checks = [
"EX01",
"SA01",
"ES01",
]
exclude = [ # don't report on objects that match any of these regex
'\.undocumented_method$',
'\.__repr__$',
'\.__init__$',
]
override_SS05 = [ # override SS05 to allow docstrings starting with these words
'^Process ',
'^Assess ',
'^Access ',
]

gvegayon marked this conversation as resolved.
Show resolved Hide resolved
[tool.poetry.group.test.dependencies]
pytest = "^8.0.0"

Expand Down
122 changes: 107 additions & 15 deletions model/src/pyrenew/latent/hospitaladmissions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-

from collections import namedtuple
from typing import Any, Optional

import jax.numpy as jnp
import numpyro as npro
Expand All @@ -14,30 +15,65 @@
["IHR", "predicted"],
defaults=[None, None],
)
"""Output from HospitalAdmissions.sample()"""
HospAdmissionsSample.__doc__ = """
A container for holding the output from HospAdmissionsSample.sample.

Attributes
----------
IHR : float or None
The infected hospitalization rate. Defaults to None.
predicted : ArrayLike or None
The predicted number of hospital admissions. Defaults to None.

Notes
-----
TODO: Add Notes.
"""

InfectHospRateSample = namedtuple(
"InfectHospRateSample",
["IHR"],
defaults=[None],
)
InfectHospRateSample.__doc__ = """
A container for holding the output from InfectHospRateSample.sample.

Attributes
----------
IHR : ArrayLike or None
The infected hospitalization rate. Defaults to None.

Notes
-----
TODO: Add Notes.
"""


class InfectHospRate(RandomVariable):
"""Infection to Hospitalization Rate"""
"""
Infection to Hospitalization Rate

Methods
-------
validate(distr)
Validates distribution is Numpyro distribution
sample(**kwargs)
Produces a sample of the IHR
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(
self,
dist: dist.Distribution,
varname: str = "IHR",
dist: Optional[dist.Distribution] = dist.LogNormal(jnp.log(0.05), 0.05),
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
varname: Optional[str] = "IHR",
) -> None:
"""Default constructor
"""
Default constructor
gvegayon marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
dist : dist.Distribution, optional
Prior distribution of the IHR, by default
dist.LogNormal(jnp.log(0.05), 0.05)
dist.LogNormal(jnp.log(0.05), 0.05).
varname : str, optional
Name of the random variable in the model, by default "IHR."

Expand All @@ -47,17 +83,43 @@ def __init__(
"""

self.validate(dist)

self.dist = dist
self.varname = varname

return None

@staticmethod
def validate(distr: dist.Distribution) -> None:
"""
Validates distribution is Numpyro distribution

Parameter
---------
distr : dist.Distribution
A ingested distribution (e.g., prior IHR distribution)

Raises
------
AssertionError
If the inputted distribution is not a Numpyro distribution.
"""
assert isinstance(distr, dist.Distribution)

def sample(self, **kwargs) -> InfectHospRateSample:
"""
Produces a sample of the IHR

Parameter
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
---------
**kwargs : dict, optional
Additional keyword arguments passed through to internal
sample calls, should there be any.

Returns
-------
InfectHospRateSample
The sampled IHR
"""
return InfectHospRateSample(
npro.sample(
name=self.varname,
Expand All @@ -71,12 +133,18 @@ class HospitalAdmissions(RandomVariable):

Implements a renewal process for the expected number of hospitalizations.

Methods
-------
validate(infect_hosp_rate_dist, weekday_effect_dist, hosp_report_prob_dist)
Validates that the IHR, weekday effects, and probability of being
reported hospitalized distributions are RandomVariable types
sample(latent, **kwargs)
Samples from the observation process

Notes
-----

The following text was directly extracted from the wastewater model
documentation
(`link <https://github.com/cdcent/cfa-forecast-renewal-ww/blob/a17efc090b2ffbc7bc11bdd9eec5198d6bcf7322/model_definition.md#hospital-admissions-component> `_).
documentation (`link <https://github.com/cdcent/cfa-forecast-renewal-ww/blob/a17efc090b2ffbc7bc11bdd9eec5198d6bcf7322/model_definition.md#hospital-admissions-component> `_).

Following other semi-mechanistic renewal frameworks, we model the _expected_
hospital admissions per capita :math:`H(t)` as a convolution of the
Expand All @@ -102,8 +170,8 @@ def __init__(
infection_to_admission_interval: RandomVariable,
infect_hosp_rate_dist: RandomVariable,
hospitalizations_predicted_varname: str = "predicted_hospitalizations",
weekday_effect_dist: RandomVariable = None,
hosp_report_prob_dist: RandomVariable = None,
weekday_effect_dist: Optional[RandomVariable] = None,
hosp_report_prob_dist: Optional[RandomVariable] = None,
) -> None:
"""Default constructor

Expand Down Expand Up @@ -150,10 +218,34 @@ def __init__(

@staticmethod
def validate(
infect_hosp_rate_dist,
weekday_effect_dist,
hosp_report_prob_dist,
infect_hosp_rate_dist: Any,
weekday_effect_dist: Any,
hosp_report_prob_dist: Any,
) -> None:
"""
Validates that the IHR, weekday effects, and probability of being
reported hospitalized distributions are RandomVariable types

Parameters
----------
infect_hosp_rate_dist : Any
Possibly incorrect input for infection to hospitalization rate distribution.
weekday_effect_dist : Any
Possibly incorrect input for weekday effect.
hosp_report_prob_dist : Any
Possibly incorrect input for distribution or fixed value for the
hospital admission reporting probability.

Returns
-------
None

Raises
------
AssertionError
If the object `distr` is not an instance of `dist.Distribution`, indicating
that the validation has failed.
"""
assert isinstance(infect_hosp_rate_dist, RandomVariable)
assert isinstance(weekday_effect_dist, RandomVariable)
assert isinstance(hosp_report_prob_dist, RandomVariable)
Expand Down
26 changes: 20 additions & 6 deletions model/src/pyrenew/latent/i0.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,26 @@
import numpyro.distributions as dist
from pyrenew.metaclass import RandomVariable

from tpying import Optional, Any

class Infections0(RandomVariable):
"""Initial infections helper class.

It creates a random variable for the initial infections with a prior
distribution.

Methods
-------
validate(i0_dist)
Validate the initial infections distribution.
sample(**kwargs)
Sample the initial infections.
"""

def __init__(
self,
name: str = "I0",
I0_dist: dist.Distribution = dist.LogNormal(0, 1),
name: Optional[str] = "I0",
I0_dist: Optional[dist.Distribution] = dist.LogNormal(0, 1),
) -> None:
"""Default constructor

Expand All @@ -36,17 +44,22 @@ def __init__(
return None

@staticmethod
def validate(i0_dist):
def validate(i0_dist: Any):
"""Validate the initial infections distribution.

Parameters
----------
i0_dist : dist.Distribution
Distribution of the initial infections.
i0_dist : Any
Distribution (expected dist.Distribution) of the initial infections.

Returns
-------
None

Raises
------
AssertionError
If the inputted distribution is not a Numpyro distribution.
"""
assert isinstance(i0_dist, dist.Distribution)

Expand All @@ -59,7 +72,8 @@ def sample(
Parameters
----------
**kwargs : dict, optional
Ignored
Additional keyword arguments passed through to internal
sample calls, should there be any.

Returns
-------
Expand Down
25 changes: 9 additions & 16 deletions model/src/pyrenew/latent/infection_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
from jax.typing import ArrayLike
from pyrenew.convolve import new_convolve_scanner, new_double_scanner

"""
infection
# """
# infection

Functions for sampling timeseries of
infections
"""
# Functions for sampling timeseries of
# infections
# # UPX3 comment given uncertainty of this docstring
# """


def sample_infections_rt(
I0: ArrayLike, Rt: ArrayLike, reversed_generation_interval_pmf: ArrayLike
):
) -> ArrayLike:
"""
Sample infections according to a
renewal process with a time-varying
Expand All @@ -27,10 +28,8 @@ def sample_infections_rt(
Array of initial infections of the
same length as the generation inferval
pmf vector.

Rt: ArrayLike
Timeseries of R(t) values

reversed_generation_interval_pmf: ArrayLike
discrete probability mass vector
representing the generation interval
Expand All @@ -54,7 +53,7 @@ def logistic_susceptibility_adjustment(
I_raw_t: float,
frac_susceptible: float,
n_population: float,
):
) -> float:
"""
Apply the logistic susceptibility
adjustment to a potential new
Expand All @@ -67,11 +66,9 @@ def logistic_susceptibility_adjustment(
The "unadjusted" incidence at time t,
i.e. the incidence given an infinite
number of available susceptible individuals.

frac_susceptible : float
fraction of remainin susceptible individuals
in the population

n_population : float
Total size of the population.

Expand All @@ -97,7 +94,7 @@ def sample_infections_with_feedback(
infection_feedback_strength: ArrayLike,
generation_interval_pmf: ArrayLike,
infection_feedback_pmf: ArrayLike,
):
) -> tuple:
"""
Sample infections according to
a renewal process with infection
Expand All @@ -110,23 +107,19 @@ def sample_infections_with_feedback(
Array of initial infections of the
same length as the generation inferval
pmf vector.

Rt_raw: ArrayLike
Timeseries of raw R(t) values not
adjusted by infection feedback

infection_feedback_strength: ArrayLike
Strength of the infection feedback.
Either a scalar (constant feedback
strength in time) or a vector representing
the infection feedback strength at a
given point in time.

generation_interval_pmf: ArrayLike
discrete probability mass vector
representing the generation interval
of the infection process

infection_feedback_pmf: ArrayLike
discrete probability mass vector
whose `i`th entry represents the
Expand Down
Loading
Loading