Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements SampledValue #262

Merged
merged 54 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
41dc41b
Adding information about future features for retrieving timeinfo
gvegayon Jul 10, 2024
c4eef20
Updating to return TimeArray (WIP)
gvegayon Jul 10, 2024
997812c
Working on tests (still 16 fails) [skip ci] expected to fail
gvegayon Jul 10, 2024
a9fbb18
Merge branch 'main' into implements-timearray
gvegayon Jul 11, 2024
752eed2
Merge branch 'main' into implements-timearray
gvegayon Jul 11, 2024
0ddd03e
Down to 5 errors
gvegayon Jul 12, 2024
015e381
Down to 1 error
gvegayon Jul 12, 2024
e073294
Fixing test. Next: merge conflicts
gvegayon Jul 17, 2024
3d7abe1
Fixing merge conflicts. Tests pass
gvegayon Jul 17, 2024
fb85885
Typo [skip ci]
gvegayon Jul 17, 2024
68f0867
Fixing tutorials
gvegayon Jul 17, 2024
813ffdf
Merge branch 'implements-timearray' of github.com:CDCgov/multisignal-…
gvegayon Jul 17, 2024
eba63e9
Working on hosp admin tutorial (expected to fail)
gvegayon Jul 17, 2024
b6c0ff0
Patching tutorial
gvegayon Jul 18, 2024
69d56c1
Making pre-commit happy
gvegayon Jul 18, 2024
1ec391e
Merge branch 'main' into implements-timearray
gvegayon Jul 18, 2024
0fad7e4
Addressing points by @damonbayer and @dylanhmorris
gvegayon Jul 22, 2024
3dff8a5
Merge branch 'implements-timearray' of github.com:CDCgov/multisignal-…
gvegayon Jul 22, 2024
6952683
Renaming TimeArray to SampledValue
gvegayon Jul 23, 2024
e4a13ce
Accessing values instead of arrays in SampledValue
gvegayon Jul 23, 2024
7213bd5
Finalizing merge
gvegayon Jul 23, 2024
6d44a70
Making pre-commit happy
gvegayon Jul 23, 2024
2bba85a
Fixing tutorial
gvegayon Jul 23, 2024
40ccf1e
Merge branch 'main' into implements-timearray
gvegayon Jul 24, 2024
2220fb0
Fixing last merge
gvegayon Jul 24, 2024
cd70eb7
Merge branch 'main' into implements-timearray
gvegayon Jul 24, 2024
a209853
Addressing comments by @damonbayer.
gvegayon Jul 24, 2024
68f1d22
Making pre-commit happy
gvegayon Jul 24, 2024
b5fc9d6
SampledValues are an instance of NamedTuple
gvegayon Jul 24, 2024
ff258d3
Adding suggestions by @dylanhmorris
gvegayon Jul 24, 2024
771f6b5
Missing comment
gvegayon Jul 24, 2024
7a8bf3b
Making pre-commit happy
gvegayon Jul 24, 2024
460993a
Fixing infections with feedback
gvegayon Jul 24, 2024
8f807dc
Removing explicit test call
gvegayon Jul 24, 2024
522c642
Merge branch 'main' into implements-timearray
dylanhmorris Jul 25, 2024
fc17488
Fix pre-commit issues and remaining numpyro/npro conflicts
dylanhmorris Jul 25, 2024
353fbc2
Addressing @dylanhmorris' comments on defaults for t_start and t_unit
gvegayon Jul 25, 2024
f62b9f4
Fixing merge
gvegayon Jul 25, 2024
b9a87f8
Forgot to remove a test call
gvegayon Jul 25, 2024
b34dca6
Merge branch 'main' into implements-timearray
gvegayon Jul 25, 2024
99c9c3b
Merge branch 'main' into implements-timearray
dylanhmorris Jul 25, 2024
6026b19
Autoformat files, fix typo caught by typos hook
dylanhmorris Jul 25, 2024
ee817e5
Update model/src/pyrenew/latent/infectionswithfeedback.py
dylanhmorris Jul 25, 2024
302ca2e
Fix typo in infectionswithfeedback.py that caused ill-formed code and…
dylanhmorris Jul 25, 2024
32cb24d
Fix tutorial bug introduced in merge conflict resolution
dylanhmorris Jul 25, 2024
43012ad
Equality assertion => almost equality assertion
dylanhmorris Jul 25, 2024
4173b75
Set identical seeds
dylanhmorris Jul 25, 2024
f635c0c
Fix another tutorial typo
dylanhmorris Jul 25, 2024
ee9dbe2
Update model/src/pyrenew/deterministic/process.py
dylanhmorris Jul 26, 2024
2177215
Fix type hinting for HospitalAdmissionsSample
dylanhmorris Jul 26, 2024
9ff9dea
Update determinsiticprocess docstring
dylanhmorris Jul 26, 2024
b298bc2
Update vars => value in docstring for DeterministicVariable
dylanhmorris Jul 26, 2024
859dbc9
Clarify relationship between t_start/t_unit of a RandomVariable and o…
dylanhmorris Jul 26, 2024
c1111fe
Update docs/source/tutorials/time.qmd
dylanhmorris Jul 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions docs/source/tutorials/time.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ The fundamental time unit should represent a period of fixed (or approximately f

For many infectious disease renewal models of interest, the fundamental time unit will be days, and we will proceed with this tutorial treating days as our fundamental unit.

`pyrenew` deals with time having `RandomVariable`s carry information about (i) their own time unit expressed relative to the fundamental unit (`t_unit`) and (ii) the starting time, `t_start`, measured relative to `t = 0` in model time in fundamental time units.
`pyrenew` deals with time having `RandomVariable`s carry information about (i) their own time unit expressed relative to the fundamental unit (`t_unit`) and (ii) the starting time, `t_start`, measured relative to `t = 0` in model time in fundamental time units. Moreover, return values from `RandomVariable.sample()` are namedtuples with `TimeArray` objects that carry the same information.

The tuple `(t_unit, t_start)` can encode different types of time series data. For example:
gvegayon marked this conversation as resolved.
Show resolved Hide resolved

Expand All @@ -31,14 +31,14 @@ The `PeriodicBroadcaster()` class provides a way of tiling and repeating data ac

The following section describes some preliminary design principles that may be included in future versions of `pyrenew`.

### Validation

With random variables possibly spanning different time scales, *e.g.*, weekly, daily, hourly, the metaclass `Model` should ensure random variables within the model share the same time unit.

### Array alignment

Using `t_unit` and `t_start`, random variables should be able to align input and output data. For example, in the case of the `RtInfectionsRenewalModel()`, the computed values of `Rt` and `infections` are padded left with `nan` values to account for the seeding process. Instead, we expect to either pre-process the padding leveraging the `t_start` information of the involved variables or simplify the process via a function call that aligns the arrays. A possible implementation could be a method `align()` that takes a list of random variables and aligns them based on the `t_unit` and `t_start` information, e.g.:

```python
Rt_aligned, infections_aligned = align([Rt, infections])
```

### Retrieving time information from sites

Since numpyro only stores Jax arrays, we cannot store the time information in the arrays themselves. Next iterations of `pyrenew` should include a way to retrieve the time information from the sites of the model after running them.
18 changes: 15 additions & 3 deletions model/src/pyrenew/deterministic/deterministic.py
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import jax.numpy as jnp
import numpyro as npro
from jax.typing import ArrayLike
from pyrenew.metaclass import RandomVariable
from pyrenew.metaclass import RandomVariable, TimeArray


class DeterministicVariable(RandomVariable):
Expand All @@ -19,6 +19,8 @@ def __init__(
self,
vars: ArrayLike,
name: str,
t_start: int | None = None,
t_unit: int | None = None,
) -> None:
"""Default constructor

Expand All @@ -28,13 +30,18 @@ def __init__(
A tuple with arraylike objects.
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
name : str, optional
A name to assign to the process.
t_start : int, optional
The start time of the process.
t_unit : int, optional
The unit of time relative to the model's fundamental (smallest) time unit.

Returns
-------
None
"""

self.validate(vars)
self.set_timeseries(t_start, t_unit)
self.vars = jnp.atleast_1d(vars)
self.name = name

Expand Down Expand Up @@ -83,8 +90,13 @@ def sample(
Returns
-------
tuple
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
Containing the stored values during construction.
Containing the stored values during construction wrapped in a TimeArrayß.
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
"""
if record:
npro.deterministic(self.name, self.vars)
return (self.vars,)
return (
TimeArray(
array = self.vars,
t_start=self.t_start,
t_unit=self.t_unit,
),)
18 changes: 15 additions & 3 deletions model/src/pyrenew/deterministic/deterministicpmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def __init__(
vars: ArrayLike,
name: str,
tol: float = 1e-5,
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
t_start: int | None = None,
t_unit: int | None = None,
) -> None:
"""
Default constructor
Expand All @@ -36,7 +38,12 @@ def __init__(
tol : float, optional
Passed to pyrenew.distutil.validate_discrete_dist_vector. Defaults
to 1e-5.

t_start : int, optional
The start time of the process.
t_unit : int, optional
The unit of time relative to the model's fundamental (smallest)
time unit.

Returns
-------
None
Expand All @@ -46,7 +53,12 @@ def __init__(
tol=tol,
)

self.basevar = DeterministicVariable(vars, name)
self.basevar = DeterministicVariable(
vars=vars,
name=name,
t_start=t_start,
t_unit=t_unit,
)

return None

Expand Down Expand Up @@ -82,7 +94,7 @@ def sample(
Returns
-------
tuple
Containing the stored values during construction.
Containing the stored values during construction wrapped in a TimeArray.
"""

return self.basevar.sample(**kwargs)
Expand Down
13 changes: 7 additions & 6 deletions model/src/pyrenew/deterministic/nullrv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from jax.typing import ArrayLike
from pyrenew.metaclass import TimeArray
from pyrenew.deterministic.deterministic import DeterministicVariable


Expand Down Expand Up @@ -46,10 +47,10 @@ def sample(
Returns
-------
tuple
Containing None.
Containing a TimeArray with None.
"""

return (None,)
return (TimeArray(None),)


class NullProcess(NullVariable):
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -95,10 +96,10 @@ def sample(
Returns
-------
tuple
Containing None.
Containing a TimeArray with None.
"""

return (None,)
return (TimeArray(None),)


class NullObservation(NullVariable):
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -151,7 +152,7 @@ def sample(
Returns
-------
tuple
Containing None.
Containing a TimeArray with None.
"""

return (None,)
return (TimeArray(None),)
22 changes: 17 additions & 5 deletions model/src/pyrenew/deterministic/process.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# numpydoc ignore=GL08

import jax.numpy as jnp
from pyrenew.metaclass import TimeArray
from pyrenew.deterministic.deterministic import DeterministicVariable


Expand Down Expand Up @@ -29,14 +30,25 @@ def sample(
Returns
-------
tuple
Containing the stored values during construction.
Containing the stored values during construction wrapped in a TimeArray.
"""

res, *_ = super().sample(**kwargs)

dif = duration - res.shape[0]
dif = duration - res.array.shape[0]

if dif > 0:
return (jnp.hstack([res, jnp.repeat(res[-1], dif)]),)

return (res[:duration],)
return (
TimeArray(
jnp.hstack([res.array, jnp.repeat(res.array[-1], dif)]), t_start=self.t_start,
t_unit=self.t_unit,
),
)

return (
TimeArray(
array=res.array[:duration],
t_start=self.t_start,
t_unit=self.t_unit
),
)
24 changes: 15 additions & 9 deletions model/src/pyrenew/latent/hospitaladmissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpyro as npro
from jax.typing import ArrayLike
from pyrenew.deterministic import DeterministicVariable
from pyrenew.metaclass import RandomVariable
from pyrenew.metaclass import RandomVariable, TimeArray


class HospitalAdmissionsSample(NamedTuple):
Expand All @@ -20,12 +20,12 @@ class HospitalAdmissionsSample(NamedTuple):
----------
infection_hosp_rate : float, optional
dylanhmorris marked this conversation as resolved.
Show resolved Hide resolved
The infection-to-hospitalization rate. Defaults to None.
latent_hospital_admissions : ArrayLike or None
latent_hospital_admissions : TimeArray or None
The computed number of hospital admissions. Defaults to None.
"""

infection_hosp_rate: float | None = None
latent_hospital_admissions: ArrayLike | None = None
latent_hospital_admissions: TimeArray | None = None

def __repr__(self):
return f"HospitalAdmissionsSample(infection_hosp_rate={self.infection_hosp_rate}, latent_hospital_admissions={self.latent_hospital_admissions})"
Expand Down Expand Up @@ -162,7 +162,7 @@ def sample(

Parameters
----------
latent : ArrayLike
latent : ArrayLike or TimeArray
Latent infections.
**kwargs : dict, optional
Additional keyword arguments passed through to internal `sample()`
Expand All @@ -175,7 +175,7 @@ def sample(

infection_hosp_rate, *_ = self.infect_hosp_rate_rv(**kwargs)

infection_hosp_rate_t = infection_hosp_rate * latent_infections
infection_hosp_rate_t = infection_hosp_rate.array * latent_infections

(
infection_to_admission_interval,
Expand All @@ -184,25 +184,31 @@ def sample(

latent_hospital_admissions = jnp.convolve(
infection_hosp_rate_t,
infection_to_admission_interval,
infection_to_admission_interval.array,
mode="full",
)[: infection_hosp_rate_t.shape[0]]

# Applying the day of the week effect
latent_hospital_admissions = (
latent_hospital_admissions
* self.day_of_week_effect_rv(**kwargs)[0]
* self.day_of_week_effect_rv(**kwargs)[0].array
)

# Applying probability of hospitalization effect
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
latent_hospital_admissions = (
latent_hospital_admissions * self.hosp_report_prob_rv(**kwargs)[0]
latent_hospital_admissions *
self.hosp_report_prob_rv(**kwargs)[0].array
)

npro.deterministic(
self.latent_hospital_admissions_varname, latent_hospital_admissions
)

return HospitalAdmissionsSample(
infection_hosp_rate, latent_hospital_admissions
infection_hosp_rate=infection_hosp_rate,
latent_hospital_admissions=TimeArray(
array=latent_hospital_admissions,
t_start=self.infection_to_admission_interval_rv.t_start,
t_unit=self.infection_to_admission_interval_rv.t_unit,
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def seed_infections(self, I_pre_seed: ArrayLike):
f"I_pre_seed must be an array of size 1. Got size {I_pre_seed.size}."
)
(rate,) = self.rate()
rate = rate.array
if rate.size != 1:
raise ValueError(
f"rate must be an array of size 1. Got size {rate.size}."
Expand Down
12 changes: 9 additions & 3 deletions model/src/pyrenew/latent/infection_initialization_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pyrenew.latent.infection_initialization_method import (
InfectionInitializationMethod,
)
from pyrenew.metaclass import RandomVariable
from pyrenew.metaclass import RandomVariable, TimeArray


class InfectionInitializationProcess(RandomVariable):
Expand Down Expand Up @@ -96,7 +96,13 @@ def sample(self) -> tuple:
a tuple where the only element is an array with the number of seeded infections at each time point.
"""
(I_pre_seed,) = self.I_pre_seed_rv()
infection_seeding = self.infection_seed_method(I_pre_seed)
infection_seeding = self.infection_seed_method(I_pre_seed.array)
npro.deterministic(self.name, infection_seeding)

return (infection_seeding,)
return (
TimeArray(
array=infection_seeding,
t_start=self.t_start,
t_unit=self.t_unit,
),
)
6 changes: 3 additions & 3 deletions model/src/pyrenew/latent/infections.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import jax.numpy as jnp
import pyrenew.latent.infection_functions as inf
from jax.typing import ArrayLike
from pyrenew.metaclass import RandomVariable
from pyrenew.metaclass import RandomVariable, TimeArray


class InfectionsSample(NamedTuple):
Expand All @@ -17,7 +17,7 @@ class InfectionsSample(NamedTuple):

Attributes
----------
post_initialization_infections : ArrayLike | None, optional
post_initialization_infections : TimeArray | None, optional
The estimated latent infections. Defaults to None.
"""

Expand Down Expand Up @@ -97,4 +97,4 @@ def sample(
reversed_generation_interval_pmf=gen_int_rev,
)

return InfectionsSample(post_initialization_infections)
return InfectionsSample(TimeArray(post_initialization_infections))
11 changes: 7 additions & 4 deletions model/src/pyrenew/latent/infectionswithfeedback.py
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import pyrenew.arrayutils as au
import pyrenew.latent.infection_functions as inf
from numpy.typing import ArrayLike
from pyrenew.metaclass import RandomVariable, _assert_sample_and_rtype
from pyrenew.metaclass import (
RandomVariable, _assert_sample_and_rtype, TimeArray
)


class InfectionsRtFeedbackSample(NamedTuple):
Expand Down Expand Up @@ -159,6 +161,7 @@ def sample(
inf_feedback_strength, *_ = self.infection_feedback_strength(
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
**kwargs,
)
inf_feedback_strength = inf_feedback_strength.array

# Making sure inf_feedback_strength spans the Rt length
if inf_feedback_strength.size == 1:
Expand All @@ -177,7 +180,7 @@ def sample(
# Sampling inf feedback pmf
inf_feedback_pmf, *_ = self.infection_feedback_pmf(**kwargs)

inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf)
inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf.array)

(
post_initialization_infections,
Expand All @@ -195,6 +198,6 @@ def sample(
npro.deterministic("Rt_adjusted", Rt_adj)

return InfectionsRtFeedbackSample(
post_initialization_infections=post_initialization_infections,
rt=Rt_adj,
post_initialization_infections=TimeArray(post_initialization_infections),
rt=TimeArray(Rt_adj),
)
Loading
Loading