Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements SampledValue #262

Merged
merged 54 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from 48 commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
41dc41b
Adding information about future features for retrieving timeinfo
gvegayon Jul 10, 2024
c4eef20
Updating to return TimeArray (WIP)
gvegayon Jul 10, 2024
997812c
Working on tests (still 16 fails) [skip ci] expected to fail
gvegayon Jul 10, 2024
a9fbb18
Merge branch 'main' into implements-timearray
gvegayon Jul 11, 2024
752eed2
Merge branch 'main' into implements-timearray
gvegayon Jul 11, 2024
0ddd03e
Down to 5 errors
gvegayon Jul 12, 2024
015e381
Down to 1 error
gvegayon Jul 12, 2024
e073294
Fixing test. Next: merge conflicts
gvegayon Jul 17, 2024
3d7abe1
Fixing merge conflicts. Tests pass
gvegayon Jul 17, 2024
fb85885
Typo [skip ci]
gvegayon Jul 17, 2024
68f0867
Fixing tutorials
gvegayon Jul 17, 2024
813ffdf
Merge branch 'implements-timearray' of github.com:CDCgov/multisignal-…
gvegayon Jul 17, 2024
eba63e9
Working on hosp admin tutorial (expected to fail)
gvegayon Jul 17, 2024
b6c0ff0
Patching tutorial
gvegayon Jul 18, 2024
69d56c1
Making pre-commit happy
gvegayon Jul 18, 2024
1ec391e
Merge branch 'main' into implements-timearray
gvegayon Jul 18, 2024
0fad7e4
Addressing points by @damonbayer and @dylanhmorris
gvegayon Jul 22, 2024
3dff8a5
Merge branch 'implements-timearray' of github.com:CDCgov/multisignal-…
gvegayon Jul 22, 2024
6952683
Renaming TimeArray to SampledValue
gvegayon Jul 23, 2024
e4a13ce
Accessing values instead of arrays in SampledValue
gvegayon Jul 23, 2024
7213bd5
Finalizing merge
gvegayon Jul 23, 2024
6d44a70
Making pre-commit happy
gvegayon Jul 23, 2024
2bba85a
Fixing tutorial
gvegayon Jul 23, 2024
40ccf1e
Merge branch 'main' into implements-timearray
gvegayon Jul 24, 2024
2220fb0
Fixing last merge
gvegayon Jul 24, 2024
cd70eb7
Merge branch 'main' into implements-timearray
gvegayon Jul 24, 2024
a209853
Addressing comments by @damonbayer.
gvegayon Jul 24, 2024
68f1d22
Making pre-commit happy
gvegayon Jul 24, 2024
b5fc9d6
SampledValues are an instance of NamedTuple
gvegayon Jul 24, 2024
ff258d3
Adding suggestions by @dylanhmorris
gvegayon Jul 24, 2024
771f6b5
Missing comment
gvegayon Jul 24, 2024
7a8bf3b
Making pre-commit happy
gvegayon Jul 24, 2024
460993a
Fixing infections with feedback
gvegayon Jul 24, 2024
8f807dc
Removing explicit test call
gvegayon Jul 24, 2024
522c642
Merge branch 'main' into implements-timearray
dylanhmorris Jul 25, 2024
fc17488
Fix pre-commit issues and remaining numpyro/npro conflicts
dylanhmorris Jul 25, 2024
353fbc2
Addressing @dylanhmorris' comments on defaults for t_start and t_unit
gvegayon Jul 25, 2024
f62b9f4
Fixing merge
gvegayon Jul 25, 2024
b9a87f8
Forgot to remove a test call
gvegayon Jul 25, 2024
b34dca6
Merge branch 'main' into implements-timearray
gvegayon Jul 25, 2024
99c9c3b
Merge branch 'main' into implements-timearray
dylanhmorris Jul 25, 2024
6026b19
Autoformat files, fix typo caught by typos hook
dylanhmorris Jul 25, 2024
ee817e5
Update model/src/pyrenew/latent/infectionswithfeedback.py
dylanhmorris Jul 25, 2024
302ca2e
Fix typo in infectionswithfeedback.py that caused ill-formed code and…
dylanhmorris Jul 25, 2024
32cb24d
Fix tutorial bug introduced in merge conflict resolution
dylanhmorris Jul 25, 2024
43012ad
Equality assertion => almost equality assertion
dylanhmorris Jul 25, 2024
4173b75
Set identical seeds
dylanhmorris Jul 25, 2024
f635c0c
Fix another tutorial typo
dylanhmorris Jul 25, 2024
ee9dbe2
Update model/src/pyrenew/deterministic/process.py
dylanhmorris Jul 26, 2024
2177215
Fix type hinting for HospitalAdmissionsSample
dylanhmorris Jul 26, 2024
9ff9dea
Update determinsiticprocess docstring
dylanhmorris Jul 26, 2024
b298bc2
Update vars => value in docstring for DeterministicVariable
dylanhmorris Jul 26, 2024
859dbc9
Clarify relationship between t_start/t_unit of a RandomVariable and o…
dylanhmorris Jul 26, 2024
c1111fe
Update docs/source/tutorials/time.qmd
dylanhmorris Jul 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/source/tutorials/basic_renewal_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,11 @@ import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 2)

# Rt plot
axs[0].plot(sim_data.Rt)
axs[0].plot(sim_data.Rt.value)
axs[0].set_ylabel("Rt")

# Infections plot
axs[1].plot(sim_data.observed_infections)
axs[1].plot(sim_data.observed_infections.value)
axs[1].set_ylabel("Infections")

fig.suptitle("Basic renewal model")
Expand All @@ -246,7 +246,7 @@ import jax
model1.run(
num_warmup=2000,
num_samples=1000,
data_observed_infections=sim_data.observed_infections,
data_observed_infections=sim_data.observed_infections.value,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False, num_chains=2),
)
Expand Down
24 changes: 14 additions & 10 deletions docs/source/tutorials/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,17 @@ The following code-chunk defines the model components. Notice that for both the
```{python}
# | label: model-components
gen_int_array = jnp.array([0.25, 0.5, 0.15, 0.1])

gen_int = DeterministicPMF(name="gen_int", value=gen_int_array)
feedback_strength = DeterministicVariable(name="feedback_strength", value=0.05)
feedback_strength = DeterministicVariable(name="feedback_strength", value=0.01)


I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)),
InitializeInfectionsExponentialGrowth(
gen_int_array.size,
DeterministicVariable(name="rate", value=0.5),
DeterministicVariable(name="rate", value=0.05),
),
t_unit=1,
)
Expand Down Expand Up @@ -103,7 +105,7 @@ with numpyro.handlers.seed(rng_seed=223):
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot(model0_samp.latent_infections)
ax.plot(model0_samp.latent_infections.value)
ax.set_xlabel("Time")
ax.set_ylabel("Infections")
plt.show()
Expand Down Expand Up @@ -160,7 +162,7 @@ The next step is to create the actual class. The bulk of its implementation lies
# | label: new-model-def
# | code-line-numbers: true
# Creating the class
from pyrenew.metaclass import RandomVariable
from pyrenew.metaclass import RandomVariable, SampledValue
from pyrenew.latent import compute_infections_from_rt_with_feedback
from pyrenew import arrayutils as au
from jax.typing import ArrayLike
Expand Down Expand Up @@ -208,12 +210,14 @@ class InfFeedback(RandomVariable):
**kwargs,
)
inf_feedback_strength = au.pad_x_to_match_y(
x=inf_feedback_strength, y=Rt, fill_value=inf_feedback_strength[0]
x=inf_feedback_strength.value,
y=Rt,
fill_value=inf_feedback_strength.value[0],
)

# Sampling inf feedback and adjusting the shape
inf_feedback_pmf, *_ = self.infection_feedback_pmf(**kwargs)
inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf)
inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf.value)

# Generating the infections with feedback
all_infections, Rt_adj = compute_infections_from_rt_with_feedback(
Expand All @@ -230,8 +234,8 @@ class InfFeedback(RandomVariable):
# Preparing theoutput

return InfFeedbackSample(
infections=all_infections,
rt=Rt_adj,
infections=SampledValue(all_infections),
rt=SampledValue(Rt_adj),
)
```

Expand Down Expand Up @@ -273,8 +277,8 @@ Comparing `model0` with `model1`, these two should match:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(ncols=2)
ax[0].plot(model0_samp.latent_infections)
ax[1].plot(model1_samp.latent_infections)
ax[0].plot(model0_samp.latent_infections.value)
ax[1].plot(model1_samp.latent_infections.value)
ax[0].set_xlabel("Time (model 0)")
ax[1].set_xlabel("Time (model 1)")
ax[0].set_ylabel("Infections")
Expand Down
6 changes: 3 additions & 3 deletions docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ class MyRt(metaclass.RandomVariable):
base_rv=process.SimpleRandomWalkProcess(
name="log_rt",
step_rv=metaclass.DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, sd_rt)
name="rw_step_rv", dist=dist.Normal(0, sd_rt.value)
),
init_rv=metaclass.DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
Expand Down Expand Up @@ -272,11 +272,11 @@ import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 2)

# Rt plot
axs[0].plot(simulated_data.Rt)
axs[0].plot(simulated_data.Rt.value)
axs[0].set_ylabel("Simulated Rt")

# Admissions plot
axs[1].plot(simulated_data.observed_hosp_admissions, "-o")
axs[1].plot(simulated_data.observed_hosp_admissions.value, "-o")
axs[1].set_ylabel("Simulated Admissions")

fig.suptitle("Basic renewal model")
Expand Down
6 changes: 4 additions & 2 deletions docs/source/tutorials/periodic_effects.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ with numpyro.handlers.seed(rng_seed=20):
# Plotting the Rt values
import matplotlib.pyplot as plt

plt.step(np.arange(len(sim_data.rt)), sim_data.rt, where="post")
plt.step(np.arange(len(sim_data.rt.value)), sim_data.rt.value, where="post")
plt.xlabel("Time")
plt.ylabel("Rt")
plt.title("Simulated Rt values")
Expand Down Expand Up @@ -92,7 +92,9 @@ with numpyro.handlers.seed(rng_seed=20):
# Plotting the effect values
import matplotlib.pyplot as plt

plt.step(np.arange(len(sim_data.value)), sim_data.value, where="post")
plt.step(
np.arange(len(sim_data.value.value)), sim_data.value.value, where="post"
)
plt.xlabel("Time")
plt.ylabel("Effect size")
plt.title("Simulated Day of Week Effect values")
Expand Down
12 changes: 6 additions & 6 deletions docs/source/tutorials/time.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ The fundamental time unit should represent a period of fixed (or approximately f

For many infectious disease renewal models of interest, the fundamental time unit will be days, and we will proceed with this tutorial treating days as our fundamental unit.

`pyrenew` deals with time having `RandomVariable`s carry information about (i) their own time unit expressed relative to the fundamental unit (`t_unit`) and (ii) the starting time, `t_start`, measured relative to `t = 0` in model time in fundamental time units.
`pyrenew` deals with time having `RandomVariable`s carry information about (i) their own time unit expressed relative to the fundamental unit (`t_unit`) and (ii) the starting time, `t_start`, measured relative to `t = 0` in model time in fundamental time units. Return values from `RandomVariable.sample()` are `tuples` or `namedtuple`s of `SampledValue` objects. Each such `SampledValue` is optionally time-aware with specifiable `t_start` and `t_unit` attributes.
dylanhmorris marked this conversation as resolved.
Show resolved Hide resolved

The tuple `(t_unit, t_start)` can encode different types of time series data. For example:
The `t_unit, t_start` pair can encode different types of time series data. For example:

| Description | `t_unit` | `t_start` |
|:-----------------|----------------:|-----------------:|
Expand All @@ -31,14 +31,14 @@ The `PeriodicBroadcaster()` class provides a way of tiling and repeating data ac

The following section describes some preliminary design principles that may be included in future versions of `pyrenew`.

### Validation

With random variables possibly spanning different time scales, *e.g.*, weekly, daily, hourly, the metaclass `Model` should ensure random variables within the model share the same time unit.

### Array alignment

Using `t_unit` and `t_start`, random variables should be able to align input and output data. For example, in the case of the `RtInfectionsRenewalModel()`, the computed values of `Rt` and `infections` are padded left with `nan` values to account for the initialization process. Instead, we expect to either pre-process the padding leveraging the `t_start` information of the involved variables or simplify the process via a function call that aligns the arrays. A possible implementation could be a method `align()` that takes a list of random variables and aligns them based on the `t_unit` and `t_start` information, e.g.:

```python
Rt_aligned, infections_aligned = align([Rt, infections])
```

### Retrieving time information from sites

Future versions of `pyrenew` could include a way to retrieve the time information for sites keyed by site name the model.
31 changes: 24 additions & 7 deletions model/src/pyrenew/deterministic/deterministic.py
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import jax.numpy as jnp
import numpyro
from jax.typing import ArrayLike
from pyrenew.metaclass import RandomVariable
from pyrenew.metaclass import RandomVariable, SampledValue


class DeterministicVariable(RandomVariable):
Expand All @@ -19,24 +19,30 @@ def __init__(
self,
name: str,
value: ArrayLike,
t_start: int | None = None,
t_unit: int | None = None,
) -> None:
"""Default constructor

Parameters
----------
name : str
A name to assign to the process.
A name to assign to the variable.
value : ArrayLike
An ArrayLike object.
t_start : int, optional
The start time of the variable, if any.
t_unit : int, optional
The unit of time relative to the model's fundamental (smallest) time unit, if any

Returns
-------
None
"""

self.name = name
self.value = jnp.atleast_1d(value)
self.validate(value)
self.set_timeseries(t_start, t_unit)

return None

Expand Down Expand Up @@ -75,16 +81,27 @@ def sample(
Parameters
----------
record : bool, optional
Whether to record the value of the deterministic RandomVariable. Defaults to True.
Whether to record the value of the deterministic
dylanhmorris marked this conversation as resolved.
Show resolved Hide resolved
RandomVariable. Defaults to True.
**kwargs : dict, optional
Additional keyword arguments passed through to internal
sample calls, should there be any.

Returns
-------
tuple
Containing the stored values during construction.
tuple[SampledValue]
A length-one tuple whose single entry is a
:class:`SampledValue`
instance with `value=self.vars`,
dylanhmorris marked this conversation as resolved.
Show resolved Hide resolved
`t_start=self.t_start`, and
`t_unit=self.t_unit`.
"""
if record:
numpyro.deterministic(self.name, self.value)
return (self.value,)
return (
SampledValue(
value=self.value,
t_start=self.t_start,
t_unit=self.t_unit,
),
)
16 changes: 14 additions & 2 deletions model/src/pyrenew/deterministic/deterministicpmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def __init__(
name: str,
value: ArrayLike,
tol: float = 1e-5,
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
t_start: int | None = None,
t_unit: int | None = None,
) -> None:
"""
Default constructor
Expand All @@ -36,6 +38,11 @@ def __init__(
tol : float, optional
Passed to pyrenew.distutil.validate_discrete_dist_vector. Defaults
to 1e-5.
t_start : int, optional
The start time of the process.
t_unit : int, optional
The unit of time relative to the model's fundamental (smallest)
time unit.

Returns
-------
Expand All @@ -46,7 +53,12 @@ def __init__(
tol=tol,
)

self.basevar = DeterministicVariable(name=name, value=value)
self.basevar = DeterministicVariable(
name=name,
value=value,
t_start=t_start,
t_unit=t_unit,
)
gvegayon marked this conversation as resolved.
Show resolved Hide resolved

return None

Expand Down Expand Up @@ -82,7 +94,7 @@ def sample(
Returns
-------
tuple
Containing the stored values during construction.
Containing the stored values during construction wrapped in a SampledValue.
"""

return self.basevar.sample(**kwargs)
Expand Down
13 changes: 7 additions & 6 deletions model/src/pyrenew/deterministic/nullrv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from jax.typing import ArrayLike
from pyrenew.deterministic.deterministic import DeterministicVariable
from pyrenew.metaclass import SampledValue


class NullVariable(DeterministicVariable):
Expand Down Expand Up @@ -46,10 +47,10 @@ def sample(
Returns
-------
tuple
Containing None.
Containing a SampledValue with None.
"""

return (None,)
return (SampledValue(None, t_start=self.t_start, t_unit=self.t_unit),)


class NullProcess(NullVariable):
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -95,10 +96,10 @@ def sample(
Returns
-------
tuple
Containing None.
Containing a SampledValue with None.
"""

return (None,)
return (SampledValue(None, t_start=self.t_start, t_unit=self.t_unit),)


class NullObservation(NullVariable):
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -148,7 +149,7 @@ def sample(
Returns
-------
tuple
Containing None.
Containing a SampledValue with None.
"""

return (None,)
return (SampledValue(None, t_start=self.t_start, t_unit=self.t_unit),)
25 changes: 20 additions & 5 deletions model/src/pyrenew/deterministic/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import jax.numpy as jnp
from pyrenew.deterministic.deterministic import DeterministicVariable
from pyrenew.metaclass import SampledValue


class DeterministicProcess(DeterministicVariable):
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -29,14 +30,28 @@ def sample(
Returns
-------
tuple
Containing the stored values during construction.
Containing the stored values during construction wrapped in a SampledValue.
dylanhmorris marked this conversation as resolved.
Show resolved Hide resolved
"""

res, *_ = super().sample(**kwargs)

dif = duration - res.shape[0]
dif = duration - res.value.shape[0]

if dif > 0:
return (jnp.hstack([res, jnp.repeat(res[-1], dif)]),)

return (res[:duration],)
res = (
SampledValue(
jnp.hstack([res.value, jnp.repeat(res.value[-1], dif)]),
t_start=self.t_start,
t_unit=self.t_unit,
),
)
else:
res = (
SampledValue(
value=res.value[:duration],
t_start=self.t_start,
t_unit=self.t_unit,
),
)

return res
dylanhmorris marked this conversation as resolved.
Show resolved Hide resolved
Loading