From f0a35ee36dc2ec4ffa7488f0f48ab77c87520876 Mon Sep 17 00:00:00 2001 From: "George G. Vega Yon" Date: Wed, 5 Jun 2024 10:32:48 -0600 Subject: [PATCH] Renaming datautils to arrayutils (#154) --- model/docs/extending_pyrenew.qmd | 6 +++--- .../src/pyrenew/{datautils.py => arrayutils.py} | 2 +- .../src/pyrenew/latent/infectionswithfeedback.py | 4 ++-- .../pyrenew/model/rtinfectionsrenewalmodel.py | 4 ++-- model/src/test/test_datautils.py | 16 ++++++++-------- 5 files changed, 16 insertions(+), 16 deletions(-) rename model/src/pyrenew/{datautils.py => arrayutils.py} (97%) diff --git a/model/docs/extending_pyrenew.qmd b/model/docs/extending_pyrenew.qmd index 02919f0a..ac9a5938 100644 --- a/model/docs/extending_pyrenew.qmd +++ b/model/docs/extending_pyrenew.qmd @@ -127,7 +127,7 @@ InfFeedbackSample = namedtuple( ) ``` -The next step is to create the actual class. The bulk of its implementation lies in the function `pyrenew.latent.compute_infections_from_rt_with_feedback()`. We will also use the `pyrenew.datautils.pad_x_to_match_y()` function to ensure the passed vectors match their lengths. The following code-chunk shows most of the implementation of the `InfectionsWithFeedback` class: +The next step is to create the actual class. The bulk of its implementation lies in the function `pyrenew.latent.compute_infections_from_rt_with_feedback()`. We will also use the `pyrenew.arrayutils.pad_x_to_match_y()` function to ensure the passed vectors match their lengths. The following code-chunk shows most of the implementation of the `InfectionsWithFeedback` class: ```{python} #| label: new-model-def @@ -135,7 +135,7 @@ The next step is to create the actual class. The bulk of its implementation lies # Creating the class from pyrenew.metaclass import RandomVariable from pyrenew.latent import compute_infections_from_rt_with_feedback -from pyrenew import datautils as du +from pyrenew import arrayutils as au from jax.typing import ArrayLike import jax.numpy as jnp @@ -181,7 +181,7 @@ class InfFeedback(RandomVariable): inf_feedback_strength, *_ = self.infection_feedback_strength.sample( **kwargs, ) - inf_feedback_strength = du.pad_x_to_match_y( + inf_feedback_strength = au.pad_x_to_match_y( x=inf_feedback_strength, y=Rt, fill_value=inf_feedback_strength[0] ) diff --git a/model/src/pyrenew/datautils.py b/model/src/pyrenew/arrayutils.py similarity index 97% rename from model/src/pyrenew/datautils.py rename to model/src/pyrenew/arrayutils.py index fe7b6944..552183af 100644 --- a/model/src/pyrenew/datautils.py +++ b/model/src/pyrenew/arrayutils.py @@ -1,5 +1,5 @@ """ -Utility functions for data processing. +Utility functions for processing arrays. """ import jax.numpy as jnp diff --git a/model/src/pyrenew/latent/infectionswithfeedback.py b/model/src/pyrenew/latent/infectionswithfeedback.py index 5e65a2e2..d1947a3a 100644 --- a/model/src/pyrenew/latent/infectionswithfeedback.py +++ b/model/src/pyrenew/latent/infectionswithfeedback.py @@ -5,7 +5,7 @@ import jax.numpy as jnp import numpyro as npro -import pyrenew.datautils as du +import pyrenew.arrayutils as au import pyrenew.latent.infection_functions as inf from numpy.typing import ArrayLike from pyrenew.metaclass import RandomVariable, _assert_sample_and_rtype @@ -160,7 +160,7 @@ def sample( # Making sure inf_feedback_strength spans the Rt length if inf_feedback_strength.size == 1: - inf_feedback_strength = du.pad_x_to_match_y( + inf_feedback_strength = au.pad_x_to_match_y( x=inf_feedback_strength, y=Rt, fill_value=inf_feedback_strength[0], diff --git a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py index 4d161d42..e8223029 100644 --- a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py +++ b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py @@ -6,7 +6,7 @@ from typing import NamedTuple import jax.numpy as jnp -import pyrenew.datautils as du +import pyrenew.arrayutils as au from numpy.typing import ArrayLike from pyrenew.deterministic import NullObservation from pyrenew.metaclass import Model, RandomVariable, _assert_sample_and_rtype @@ -291,7 +291,7 @@ def sample( # is merged. # SEE ALSO: # https://github.com/CDCgov/multisignal-epi-inference/pull/123#discussion_r1612337288 - i0 = du.pad_x_to_match_y(x=i0, y=gen_int, fill_value=0.0) + i0 = au.pad_x_to_match_y(x=i0, y=gen_int, fill_value=0.0) # Sampling from the latent process latent, *_ = self.sample_infections_latent( diff --git a/model/src/test/test_datautils.py b/model/src/test/test_datautils.py index 860c5f60..8ee35b68 100644 --- a/model/src/test/test_datautils.py +++ b/model/src/test/test_datautils.py @@ -1,13 +1,13 @@ """ -Tests for the datautils module. +Tests for the arrayutils module. """ import jax.numpy as jnp -import pyrenew.datautils as du +import pyrenew.arrayutils as au import pytest -def test_datautils_pad_to_match(): +def test_arrayutils_pad_to_match(): """ Verifies extension when required and error when `fix_y` is True. """ @@ -15,7 +15,7 @@ def test_datautils_pad_to_match(): x = jnp.array([1, 2, 3]) y = jnp.array([1, 2]) - x_pad, y_pad = du.pad_to_match(x, y) + x_pad, y_pad = au.pad_to_match(x, y) assert x_pad.size == y_pad.size assert x_pad.size == 3 @@ -23,7 +23,7 @@ def test_datautils_pad_to_match(): x = jnp.array([1, 2]) y = jnp.array([1, 2, 3]) - x_pad, y_pad = du.pad_to_match(x, y) + x_pad, y_pad = au.pad_to_match(x, y) assert x_pad.size == y_pad.size assert x_pad.size == 3 @@ -33,10 +33,10 @@ def test_datautils_pad_to_match(): # Verify that the function raises an error when `fix_y` is True with pytest.raises(ValueError): - x_pad, y_pad = du.pad_to_match(x, y, fix_y=True) + x_pad, y_pad = au.pad_to_match(x, y, fix_y=True) -def test_datautils_pad_x_to_match_y(): +def test_arrayutils_pad_x_to_match_y(): """ Verifies extension when required """ @@ -44,6 +44,6 @@ def test_datautils_pad_x_to_match_y(): x = jnp.array([1, 2]) y = jnp.array([1, 2, 3]) - x_pad = du.pad_x_to_match_y(x, y) + x_pad = au.pad_x_to_match_y(x, y) assert x_pad.size == 3