Skip to content

Commit

Permalink
Implements SampledValue (#262)
Browse files Browse the repository at this point in the history
  • Loading branch information
gvegayon authored Jul 26, 2024
1 parent 7fd138d commit 5a3f0e2
Show file tree
Hide file tree
Showing 41 changed files with 435 additions and 216 deletions.
6 changes: 3 additions & 3 deletions docs/source/tutorials/basic_renewal_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,11 @@ import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 2)
# Rt plot
axs[0].plot(sim_data.Rt)
axs[0].plot(sim_data.Rt.value)
axs[0].set_ylabel("Rt")
# Infections plot
axs[1].plot(sim_data.observed_infections)
axs[1].plot(sim_data.observed_infections.value)
axs[1].set_ylabel("Infections")
fig.suptitle("Basic renewal model")
Expand All @@ -246,7 +246,7 @@ import jax
model1.run(
num_warmup=2000,
num_samples=1000,
data_observed_infections=sim_data.observed_infections,
data_observed_infections=sim_data.observed_infections.value,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False, num_chains=2),
)
Expand Down
24 changes: 14 additions & 10 deletions docs/source/tutorials/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,17 @@ The following code-chunk defines the model components. Notice that for both the
```{python}
# | label: model-components
gen_int_array = jnp.array([0.25, 0.5, 0.15, 0.1])
gen_int = DeterministicPMF(name="gen_int", value=gen_int_array)
feedback_strength = DeterministicVariable(name="feedback_strength", value=0.05)
feedback_strength = DeterministicVariable(name="feedback_strength", value=0.01)
I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)),
InitializeInfectionsExponentialGrowth(
gen_int_array.size,
DeterministicVariable(name="rate", value=0.5),
DeterministicVariable(name="rate", value=0.05),
),
t_unit=1,
)
Expand Down Expand Up @@ -103,7 +105,7 @@ with numpyro.handlers.seed(rng_seed=223):
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot(model0_samp.latent_infections)
ax.plot(model0_samp.latent_infections.value)
ax.set_xlabel("Time")
ax.set_ylabel("Infections")
plt.show()
Expand Down Expand Up @@ -160,7 +162,7 @@ The next step is to create the actual class. The bulk of its implementation lies
# | label: new-model-def
# | code-line-numbers: true
# Creating the class
from pyrenew.metaclass import RandomVariable
from pyrenew.metaclass import RandomVariable, SampledValue
from pyrenew.latent import compute_infections_from_rt_with_feedback
from pyrenew import arrayutils as au
from jax.typing import ArrayLike
Expand Down Expand Up @@ -208,12 +210,14 @@ class InfFeedback(RandomVariable):
**kwargs,
)
inf_feedback_strength = au.pad_x_to_match_y(
x=inf_feedback_strength, y=Rt, fill_value=inf_feedback_strength[0]
x=inf_feedback_strength.value,
y=Rt,
fill_value=inf_feedback_strength.value[0],
)
# Sampling inf feedback and adjusting the shape
inf_feedback_pmf, *_ = self.infection_feedback_pmf(**kwargs)
inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf)
inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf.value)
# Generating the infections with feedback
all_infections, Rt_adj = compute_infections_from_rt_with_feedback(
Expand All @@ -230,8 +234,8 @@ class InfFeedback(RandomVariable):
# Preparing theoutput
return InfFeedbackSample(
infections=all_infections,
rt=Rt_adj,
infections=SampledValue(all_infections),
rt=SampledValue(Rt_adj),
)
```

Expand Down Expand Up @@ -273,8 +277,8 @@ Comparing `model0` with `model1`, these two should match:
import matplotlib.pyplot as plt
fig, ax = plt.subplots(ncols=2)
ax[0].plot(model0_samp.latent_infections)
ax[1].plot(model1_samp.latent_infections)
ax[0].plot(model0_samp.latent_infections.value)
ax[1].plot(model1_samp.latent_infections.value)
ax[0].set_xlabel("Time (model 0)")
ax[1].set_xlabel("Time (model 1)")
ax[0].set_ylabel("Infections")
Expand Down
6 changes: 3 additions & 3 deletions docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ class MyRt(metaclass.RandomVariable):
base_rv=process.SimpleRandomWalkProcess(
name="log_rt",
step_rv=metaclass.DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, sd_rt)
name="rw_step_rv", dist=dist.Normal(0, sd_rt.value)
),
init_rv=metaclass.DistributionalRV(
name="init_log_Rt_rv", dist=dist.Normal(0, 0.2)
Expand Down Expand Up @@ -272,11 +272,11 @@ import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 2)
# Rt plot
axs[0].plot(simulated_data.Rt)
axs[0].plot(simulated_data.Rt.value)
axs[0].set_ylabel("Simulated Rt")
# Admissions plot
axs[1].plot(simulated_data.observed_hosp_admissions, "-o")
axs[1].plot(simulated_data.observed_hosp_admissions.value, "-o")
axs[1].set_ylabel("Simulated Admissions")
fig.suptitle("Basic renewal model")
Expand Down
6 changes: 4 additions & 2 deletions docs/source/tutorials/periodic_effects.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ with numpyro.handlers.seed(rng_seed=20):
# Plotting the Rt values
import matplotlib.pyplot as plt
plt.step(np.arange(len(sim_data.rt)), sim_data.rt, where="post")
plt.step(np.arange(len(sim_data.rt.value)), sim_data.rt.value, where="post")
plt.xlabel("Time")
plt.ylabel("Rt")
plt.title("Simulated Rt values")
Expand Down Expand Up @@ -92,7 +92,9 @@ with numpyro.handlers.seed(rng_seed=20):
# Plotting the effect values
import matplotlib.pyplot as plt
plt.step(np.arange(len(sim_data.value)), sim_data.value, where="post")
plt.step(
np.arange(len(sim_data.value.value)), sim_data.value.value, where="post"
)
plt.xlabel("Time")
plt.ylabel("Effect size")
plt.title("Simulated Day of Week Effect values")
Expand Down
19 changes: 13 additions & 6 deletions docs/source/tutorials/time.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,16 @@ The fundamental time unit should represent a period of fixed (or approximately f

For many infectious disease renewal models of interest, the fundamental time unit will be days, and we will proceed with this tutorial treating days as our fundamental unit.

`pyrenew` deals with time having `RandomVariable`s carry information about (i) their own time unit expressed relative to the fundamental unit (`t_unit`) and (ii) the starting time, `t_start`, measured relative to `t = 0` in model time in fundamental time units.
`pyrenew` deals with time by having `RandomVariable`s carry information about

The tuple `(t_unit, t_start)` can encode different types of time series data. For example:
1. their own time unit expressed relative to the fundamental unit (`t_unit`) and
2. the starting time, `t_start`, measured relative to `t = 0` in model time in fundamental time units.

Return values from `RandomVariable.sample()` are `tuples` or `namedtuple`s of `SampledValue` objects. `SampledValue` objects can have `t_start` and `t_unit` attributes.

By default, `SampledValue` objects carry the `t_start` and `t_unit` of the `RandomVariable` from which they are `sample()`-d. One might override this default to allow a `RandomVariable.sample()` call to produce multiple `SampledValue`s with different time-units, or with different start-points relative to the `RandomVariable`'s own `t_start`.

The `t_unit, t_start` pair can encode different types of time series data. For example:

| Description | `t_unit` | `t_start` |
|:-----------------|----------------:|-----------------:|
Expand All @@ -31,14 +38,14 @@ The `PeriodicBroadcaster()` class provides a way of tiling and repeating data ac

The following section describes some preliminary design principles that may be included in future versions of `pyrenew`.

### Validation

With random variables possibly spanning different time scales, *e.g.*, weekly, daily, hourly, the metaclass `Model` should ensure random variables within the model share the same time unit.

### Array alignment

Using `t_unit` and `t_start`, random variables should be able to align input and output data. For example, in the case of the `RtInfectionsRenewalModel()`, the computed values of `Rt` and `infections` are padded left with `nan` values to account for the initialization process. Instead, we expect to either pre-process the padding leveraging the `t_start` information of the involved variables or simplify the process via a function call that aligns the arrays. A possible implementation could be a method `align()` that takes a list of random variables and aligns them based on the `t_unit` and `t_start` information, e.g.:

```python
Rt_aligned, infections_aligned = align([Rt, infections])
```

### Retrieving time information from sites

Future versions of `pyrenew` could include a way to retrieve the time information for sites keyed by site name the model.
31 changes: 24 additions & 7 deletions model/src/pyrenew/deterministic/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import jax.numpy as jnp
import numpyro
from jax.typing import ArrayLike
from pyrenew.metaclass import RandomVariable
from pyrenew.metaclass import RandomVariable, SampledValue


class DeterministicVariable(RandomVariable):
Expand All @@ -19,24 +19,30 @@ def __init__(
self,
name: str,
value: ArrayLike,
t_start: int | None = None,
t_unit: int | None = None,
) -> None:
"""Default constructor
Parameters
----------
name : str
A name to assign to the process.
A name to assign to the variable.
value : ArrayLike
An ArrayLike object.
t_start : int, optional
The start time of the variable, if any.
t_unit : int, optional
The unit of time relative to the model's fundamental (smallest) time unit, if any
Returns
-------
None
"""

self.name = name
self.value = jnp.atleast_1d(value)
self.validate(value)
self.set_timeseries(t_start, t_unit)

return None

Expand Down Expand Up @@ -75,16 +81,27 @@ def sample(
Parameters
----------
record : bool, optional
Whether to record the value of the deterministic RandomVariable. Defaults to True.
Whether to record the value of the deterministic
RandomVariable. Defaults to True.
**kwargs : dict, optional
Additional keyword arguments passed through to internal
sample calls, should there be any.
Returns
-------
tuple
Containing the stored values during construction.
tuple[SampledValue]
A length-one tuple whose single entry is a
:class:`SampledValue`
instance with `value=self.value`,
`t_start=self.t_start`, and
`t_unit=self.t_unit`.
"""
if record:
numpyro.deterministic(self.name, self.value)
return (self.value,)
return (
SampledValue(
value=self.value,
t_start=self.t_start,
t_unit=self.t_unit,
),
)
16 changes: 14 additions & 2 deletions model/src/pyrenew/deterministic/deterministicpmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def __init__(
name: str,
value: ArrayLike,
tol: float = 1e-5,
t_start: int | None = None,
t_unit: int | None = None,
) -> None:
"""
Default constructor
Expand All @@ -36,6 +38,11 @@ def __init__(
tol : float, optional
Passed to pyrenew.distutil.validate_discrete_dist_vector. Defaults
to 1e-5.
t_start : int, optional
The start time of the process.
t_unit : int, optional
The unit of time relative to the model's fundamental (smallest)
time unit.
Returns
-------
Expand All @@ -46,7 +53,12 @@ def __init__(
tol=tol,
)

self.basevar = DeterministicVariable(name=name, value=value)
self.basevar = DeterministicVariable(
name=name,
value=value,
t_start=t_start,
t_unit=t_unit,
)

return None

Expand Down Expand Up @@ -82,7 +94,7 @@ def sample(
Returns
-------
tuple
Containing the stored values during construction.
Containing the stored values during construction wrapped in a SampledValue.
"""

return self.basevar.sample(**kwargs)
Expand Down
13 changes: 7 additions & 6 deletions model/src/pyrenew/deterministic/nullrv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from jax.typing import ArrayLike
from pyrenew.deterministic.deterministic import DeterministicVariable
from pyrenew.metaclass import SampledValue


class NullVariable(DeterministicVariable):
Expand Down Expand Up @@ -46,10 +47,10 @@ def sample(
Returns
-------
tuple
Containing None.
Containing a SampledValue with None.
"""

return (None,)
return (SampledValue(None, t_start=self.t_start, t_unit=self.t_unit),)


class NullProcess(NullVariable):
Expand Down Expand Up @@ -95,10 +96,10 @@ def sample(
Returns
-------
tuple
Containing None.
Containing a SampledValue with None.
"""

return (None,)
return (SampledValue(None, t_start=self.t_start, t_unit=self.t_unit),)


class NullObservation(NullVariable):
Expand Down Expand Up @@ -148,7 +149,7 @@ def sample(
Returns
-------
tuple
Containing None.
Containing a SampledValue with None.
"""

return (None,)
return (SampledValue(None, t_start=self.t_start, t_unit=self.t_unit),)
20 changes: 15 additions & 5 deletions model/src/pyrenew/deterministic/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import jax.numpy as jnp
from pyrenew.deterministic.deterministic import DeterministicVariable
from pyrenew.metaclass import SampledValue


class DeterministicProcess(DeterministicVariable):
Expand All @@ -28,15 +29,24 @@ def sample(
Returns
-------
tuple
Containing the stored values during construction.
tuple[SampledValue]
containing the deterministic value(s) provided
at construction as a series of length `duration`.
"""

res, *_ = super().sample(**kwargs)

dif = duration - res.shape[0]
dif = duration - res.value.shape[0]

if dif > 0:
return (jnp.hstack([res, jnp.repeat(res[-1], dif)]),)
value = jnp.hstack([res.value, jnp.repeat(res.value[-1], dif)])
else:
value = res.value[:duration]

return (res[:duration],)
res = SampledValue(
value,
t_start=self.t_start,
t_unit=self.t_unit,
)

return (res,)
Loading

0 comments on commit 5a3f0e2

Please sign in to comment.