From 07d449cd6152de7068b62e80083e539f1701d4c8 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 29 Jul 2024 18:03:21 -0700 Subject: [PATCH 1/3] Add new tests for deterministicvariable --- model/src/test/test_deterministic.py | 45 ++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/model/src/test/test_deterministic.py b/model/src/test/test_deterministic.py index 6e72f8cb..b371695c 100644 --- a/model/src/test/test_deterministic.py +++ b/model/src/test/test_deterministic.py @@ -1,7 +1,9 @@ # numpydoc ignore=GL08 import jax.numpy as jnp +import numpy as np import numpy.testing as testing +import pytest from pyrenew.deterministic import ( DeterministicPMF, DeterministicProcess, @@ -62,3 +64,46 @@ def test_deterministic(): testing.assert_equal(var4()[0].value, None) testing.assert_equal(var5(duration=1)[0].value, None) + + +def test_deterministic_validation(): + """ + Check that validation methods for DeterministicVariable + work as expected. + """ + # validation should fail on construction + some_non_array_likes = [ + {"a": jnp.array([1, 2.5, 3])}, + # a valid pytree, but not an arraylike + "a string", + ] + some_array_likes = [ + 5, + -3.023523, + np.array([1, 3.32, 5]), + jnp.array([-32, 23]), + jnp.array(-32), + np.array(5), + ] + + for non_arraylike_val in some_non_array_likes: + with pytest.raises( + ValueError, match="value is not an ArrayLike object" + ): + # the class's validation function itself + # should raise an error when passed a + # non arraylike value + DeterministicVariable.validate(non_arraylike_val) + + with pytest.raises( + ValueError, match="value is not an ArrayLike object" + ): + # validation should fail on constructor call + DeterministicVariable( + value=non_arraylike_val, name="invalid_variable" + ) + + # validation should succeed with ArrayLike + for arraylike_val in some_array_likes: + DeterministicVariable.validate(arraylike_val) + DeterministicVariable(value=arraylike_val, name="valid_variable") From 78f7aad522199bf521a02d6864c148f0db6c8902 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 29 Jul 2024 18:05:04 -0700 Subject: [PATCH 2/3] Fix docs, error classes, and order of validation in deterministic.py --- model/src/pyrenew/deterministic/deterministic.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/model/src/pyrenew/deterministic/deterministic.py b/model/src/pyrenew/deterministic/deterministic.py index 2bb03333..526d68bc 100644 --- a/model/src/pyrenew/deterministic/deterministic.py +++ b/model/src/pyrenew/deterministic/deterministic.py @@ -40,8 +40,8 @@ def __init__( None """ self.name = name - self.value = jnp.atleast_1d(value) self.validate(value) + self.value = jnp.atleast_1d(value) self.set_timeseries(t_start, t_unit) return None @@ -49,7 +49,7 @@ def __init__( @staticmethod def validate(value: ArrayLike) -> None: """ - Validates input to DeterministicPMF + Validates input to DeterministicVariable Parameters ---------- @@ -63,10 +63,10 @@ def validate(value: ArrayLike) -> None: Raises ------ Exception - If the input value object is not a ArrayLike. + If the input value object is not an ArrayLike object. """ if not isinstance(value, ArrayLike): - raise Exception("value is not a ArrayLike") + raise ValueError("value is not an ArrayLike object") return None From 73711ce4d0b5400ebff069cf5ae29d390eda8493 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Tue, 30 Jul 2024 12:19:57 -0400 Subject: [PATCH 3/3] More meaningful error message and test for it --- .../src/pyrenew/deterministic/deterministic.py | 6 +++++- model/src/test/test_deterministic.py | 17 +++++++++++------ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/model/src/pyrenew/deterministic/deterministic.py b/model/src/pyrenew/deterministic/deterministic.py index 526d68bc..2b626ef2 100644 --- a/model/src/pyrenew/deterministic/deterministic.py +++ b/model/src/pyrenew/deterministic/deterministic.py @@ -66,7 +66,11 @@ def validate(value: ArrayLike) -> None: If the input value object is not an ArrayLike object. """ if not isinstance(value, ArrayLike): - raise ValueError("value is not an ArrayLike object") + raise ValueError( + f"value {value} passed to a DeterministicVariable " + f"is of type {type(value).__name__}, expected " + "an ArrayLike object" + ) return None diff --git a/model/src/test/test_deterministic.py b/model/src/test/test_deterministic.py index b371695c..39e08e65 100644 --- a/model/src/test/test_deterministic.py +++ b/model/src/test/test_deterministic.py @@ -1,5 +1,7 @@ # numpydoc ignore=GL08 +import re + import jax.numpy as jnp import numpy as np import numpy.testing as testing @@ -87,17 +89,20 @@ def test_deterministic_validation(): ] for non_arraylike_val in some_non_array_likes: - with pytest.raises( - ValueError, match="value is not an ArrayLike object" - ): + matchval = re.escape( + f"value {non_arraylike_val} passed to a " + "DeterministicVariable is of type " + f"{type(non_arraylike_val).__name__}, expected " + "an ArrayLike object" + ) + + with pytest.raises(ValueError, match=matchval): # the class's validation function itself # should raise an error when passed a # non arraylike value DeterministicVariable.validate(non_arraylike_val) - with pytest.raises( - ValueError, match="value is not an ArrayLike object" - ): + with pytest.raises(ValueError, match=matchval): # validation should fail on constructor call DeterministicVariable( value=non_arraylike_val, name="invalid_variable"