Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring the periodic broadcaster and added warn + error + test #349

Merged
merged 15 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/tutorials/periodic_effects.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ for i in range(0, 30, 7):
plt.show()
```

The implementation of the `RtWeeklyDiffProcess` (which is an instance of `RtPeriodicDiffProcess`), uses `PeriodicBroadcaster` to repeating values: `PeriodicBroadcaster(..., period_size=7, broadcast_type="repeat")`. Setting the `broadcast_type` to `"repeat"` repeats each vector element for the specified period size. The `RtWeeklyDiff` class is a particular case of `RtPeriodicDiff` with a period size of seven.
The implementation of the `RtWeeklyDiffProcess` (which is an instance of `RtPeriodicDiffProcess`), uses `repeat_until_n` to repeating values: `repeat_until_n(..., period_size=7)`. The `RtWeeklyDiff` class is a particular case of `RtPeriodicDiff` with a period size of seven.

## Repeated sequences (tiling)

Expand Down
2 changes: 1 addition & 1 deletion docs/source/tutorials/time.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ The `t_unit, t_start` pair can encode different types of time series data. For e

## How it relates to periodicity

The `PeriodicBroadcaster()` class provides a way of tiling and repeating data accounting starting time, but it does not encode the time unit, only the period length and starting point. Furthermore, samples returned from `PeriodicEffect()` and `RtPeriodicDiffProcess()` both currently return daily values shifted so that the first entry of their arrays matches day 0 in the model.
The `tile_until_n()` and `repeat_until_n()` functions provide a way of tiling and repeating data accounting starting time, but it does not encode the time unit, only the period length and starting point. Furthermore, samples returned from `PeriodicEffect()` and `RtPeriodicDiffProcess()` both currently return daily values shifted so that the first entry of their arrays matches day 0 in the model.
gvegayon marked this conversation as resolved.
Show resolved Hide resolved

## Unimplemented features

Expand Down
231 changes: 102 additions & 129 deletions model/src/pyrenew/arrayutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,138 +112,111 @@ def __repr__(self):
return f"PeriodicProcessSample(value={self.value})"


class PeriodicBroadcaster:
r"""
Broadcast arrays periodically using either repeat or tile,
considering period size and starting point.
def tile_until_n(
data: ArrayLike,
n_timepoints: int,
offset: int | None = 0,
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
) -> ArrayLike:
"""
Tile the data until it reaches `n_timepoints`.

def __init__(
self,
offset: int,
period_size: int,
broadcast_type: str,
) -> None:
"""
Default constructor for PeriodicBroadcaster class.

Parameters
----------
offset : int
Relative point at which data starts, must be between 0 and
period_size - 1.
period_size : int
Size of the period.
broadcast_type : str
Type of broadcasting to use, either "repeat" or "tile".

Notes
-----
See the sample method for more information on the broadcasting types.

Returns
-------
None
"""

self.validate(
offset=offset,
period_size=period_size,
broadcast_type=broadcast_type,
)
Parameters
----------
data : ArrayLike
Data to broadcast.
n_timepoints : int
Duration of the sequence.
offset : int, optional
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
Relative point at which data starts, must be a non-negative integer.

Notes
-----
Using the `offset` parameter, the function will start the broadcast
from the `offset`-th element of the data. If the data is shorter than
`n_timepoints`, the function will repeat or tile the data until it
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
reaches `n_timepoints`.

self.period_size = period_size
self.offset = offset
self.broadcast_type = broadcast_type

return None

@staticmethod
def validate(offset: int, period_size: int, broadcast_type: str) -> None:
"""
Validate the input parameters.

Parameters
----------
offset : int
Relative point at which data starts, must be between 0 and
period_size - 1.
period_size : int
Size of the period.
broadcast_type : str
Type of broadcasting to use, either "repeat" or "tile".

Returns
-------
None
"""

# Period size should be a positive integer
assert isinstance(
period_size, int
), f"period_size should be an integer. It is {type(period_size)}."

assert (
period_size > 0
), f"period_size should be a positive integer. It is {period_size}."

# Data starts should be a positive integer
assert isinstance(
offset, int
), f"offset should be an integer. It is {type(offset)}."

assert (
0 <= offset
), f"offset should be a positive integer. It is {offset}."

assert offset <= period_size - 1, (
"offset should be less than or equal to period_size - 1."
f"It is {offset}. It should be less than or equal "
f"to {period_size - 1}."
)
Returns
-------
ArrayLike
Tiled data.
"""

# Data starts should be a positive integer
assert isinstance(
offset, int
), f"offset should be an integer. It is {type(offset)}."

assert 0 <= offset, f"offset should be a positive integer. It is {offset}."

return jnp.tile(data, (n_timepoints // data.size) + 1)[
offset : (offset + n_timepoints)
]


def repeat_until_n(
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
data: ArrayLike,
n_timepoints: int,
period_size: int,
offset: int | None = 0,
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
):
"""
Repeat the data until it reaches `n_timepoints`.
gvegayon marked this conversation as resolved.
Show resolved Hide resolved

Notes
-----
Using the `offset` parameter, the function will start the broadcast
from the `offset`-th element of the data. If the data is shorter than
`n_timepoints`, the function will repeat or tile the data until it
reaches `n_timepoints`.

# Broadcast type should be either "repeat" or "tile"
assert broadcast_type in ["repeat", "tile"], (
"broadcast_type should be either 'repeat' or 'tile'. "
f"It is {broadcast_type}."
Parameters
----------
data : ArrayLike
Data to broadcast.
n_timepoints : int
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
Duration of the sequence.
period_size : int
Size of the period for the repeat broadcast.
offset : int, optional
Relative point at which data starts, must be between 0 and
period_size - 1. By default 0.

Returns
-------
ArrayLike
Repeated data.
"""

# Data starts should be a positive integer
assert isinstance(
offset, int
), f"offset should be an integer. It is {type(offset)}."

assert 0 <= offset, f"offset should be a positive integer. It is {offset}."

# Period size should be a positive integer
assert isinstance(
period_size, int
), f"period_size should be an integer. It is {type(period_size)}."

assert (
period_size > 0
), f"period_size should be a positive integer. It is {period_size}."

assert offset <= period_size - 1, (
"offset should be less than or equal to period_size - 1."
f"It is {offset}. It should be less than or equal "
f"to {period_size - 1}."
)

if (data.size * period_size) < n_timepoints:
raise ValueError(
"The data is too short to broadcast to "
f"the given number of timepoints ({n_timepoints}). The "
"repeated data would have a size of data.size * "
f"period_size = {data.size} * {period_size} = "
f"{data.size * period_size}."
)

return None

def __call__(
self,
data: ArrayLike,
n_timepoints: int,
) -> ArrayLike:
"""
Broadcast the data to the given number of timepoints
considering the period size and starting point.

Parameters
----------
data: ArrayLike
Data to broadcast.
n_timepoints : int
Duration of the sequence.

Notes
-----
The broadcasting is done by repeating or tiling the data. When
self.broadcast_type = "repeat", the function will repeat each value of the data `self.period_size` times until it reaches `n_timepoints`. When self.broadcast_type = "tile", the function will tile the data until it reaches `n_timepoints`.

Using the `offset` parameter, the function will start the broadcast from the `offset`-th element of the data. If the data is shorter than `n_timepoints`, the function will repeat or tile the data until it reaches `n_timepoints`.

Returns
-------
ArrayLike
Broadcasted array.
"""

if self.broadcast_type == "repeat":
return jnp.repeat(data, self.period_size)[
self.offset : (self.offset + n_timepoints)
]
elif self.broadcast_type == "tile":
return jnp.tile(
data, int(jnp.ceil(n_timepoints / self.period_size))
)[self.offset : (self.offset + n_timepoints)]
return jnp.repeat(data, period_size)[offset : (offset + n_timepoints)]
13 changes: 3 additions & 10 deletions model/src/pyrenew/process/periodiceffect.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ class PeriodicEffect(RandomVariable):
def __init__(
self,
offset: int,
period_size: int,
quantity_to_broadcast: RandomVariable,
t_start: int,
t_unit: int,
Expand All @@ -48,8 +47,6 @@ def __init__(
offset : int
Relative point at which data starts, must be between 0 and
period_size - 1.
period_size : int
Size of the period.
quantity_to_broadcast : RandomVariable
Values to be broadcasted (repeated or tiled).
t_start : int
Expand All @@ -64,11 +61,7 @@ def __init__(

PeriodicEffect.validate(quantity_to_broadcast)

self.broadcaster = au.PeriodicBroadcaster(
offset=offset,
period_size=period_size,
broadcast_type="tile",
)
self.offset = offset

self.set_timeseries(
t_start=t_start,
Expand Down Expand Up @@ -114,9 +107,10 @@ def sample(self, duration: int, **kwargs):

return PeriodicEffectSample(
value=SampledValue(
self.broadcaster(
au.tile_until_n(
data=self.quantity_to_broadcast.sample(**kwargs)[0].value,
n_timepoints=duration,
offset=self.offset,
),
t_start=self.t_start,
t_unit=self.t_unit,
Expand Down Expand Up @@ -157,7 +151,6 @@ def __init__(

super().__init__(
offset=offset,
period_size=7,
quantity_to_broadcast=quantity_to_broadcast,
t_start=t_start,
t_unit=1,
Expand Down
16 changes: 8 additions & 8 deletions model/src/pyrenew/process/rtperiodicdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from typing import NamedTuple

import jax.numpy as jnp
import pyrenew.arrayutils as au
from jax.typing import ArrayLike
from pyrenew.arrayutils import PeriodicBroadcaster
from pyrenew.metaclass import (
RandomVariable,
SampledValue,
Expand Down Expand Up @@ -77,19 +77,14 @@ def __init__(
-------
None
"""
self.name = name
self.broadcaster = PeriodicBroadcaster(
offset=offset,
period_size=period_size,
broadcast_type="repeat",
)

self.validate(
log_rt_rv=log_rt_rv,
autoreg_rv=autoreg_rv,
periodic_diff_sd_rv=periodic_diff_sd_rv,
)

self.name = name
self.period_size = period_size
self.offset = offset
self.log_rt_rv = log_rt_rv
Expand Down Expand Up @@ -192,7 +187,12 @@ def sample(

return RtPeriodicDiffProcessSample(
rt=SampledValue(
self.broadcaster(jnp.exp(log_rt.value.flatten()), duration),
au.repeat_until_n(
data=jnp.exp(log_rt.value.flatten()),
n_timepoints=duration,
offset=self.offset,
period_size=self.period_size,
),
t_start=self.t_start,
t_unit=self.t_unit,
),
Expand Down
40 changes: 40 additions & 0 deletions model/src/test/test_broadcaster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""
Test the broadcaster utility
"""

import jax.numpy as jnp
import numpy.testing as testing
import pytest
from pyrenew.arrayutils import repeat_until_n, tile_until_n


def test_broadcaster() -> None:
"""
Test the PeriodicBroadcaster utility.
"""
base_array = jnp.array([1, 2, 3])

testing.assert_array_equal(
tile_until_n(base_array, 10),
jnp.array([1, 2, 3, 1, 2, 3, 1, 2, 3, 1]),
)

testing.assert_array_equal(
repeat_until_n(
data=base_array,
n_timepoints=10,
offset=0,
period_size=7,
),
jnp.array([1, 1, 1, 1, 1, 1, 1, 2, 2, 2]),
)

with pytest.raises(ValueError, match="The data is too short to broadcast"):
repeat_until_n(
data=base_array,
n_timepoints=100,
offset=0,
period_size=7,
)

return None
Loading