diff --git a/model/src/pyrenew/deterministic/deterministic.py b/model/src/pyrenew/deterministic/deterministic.py index 2bb03333..2b626ef2 100644 --- a/model/src/pyrenew/deterministic/deterministic.py +++ b/model/src/pyrenew/deterministic/deterministic.py @@ -40,8 +40,8 @@ def __init__( None """ self.name = name - self.value = jnp.atleast_1d(value) self.validate(value) + self.value = jnp.atleast_1d(value) self.set_timeseries(t_start, t_unit) return None @@ -49,7 +49,7 @@ def __init__( @staticmethod def validate(value: ArrayLike) -> None: """ - Validates input to DeterministicPMF + Validates input to DeterministicVariable Parameters ---------- @@ -63,10 +63,14 @@ def validate(value: ArrayLike) -> None: Raises ------ Exception - If the input value object is not a ArrayLike. + If the input value object is not an ArrayLike object. """ if not isinstance(value, ArrayLike): - raise Exception("value is not a ArrayLike") + raise ValueError( + f"value {value} passed to a DeterministicVariable " + f"is of type {type(value).__name__}, expected " + "an ArrayLike object" + ) return None diff --git a/model/src/test/test_deterministic.py b/model/src/test/test_deterministic.py index 6e72f8cb..39e08e65 100644 --- a/model/src/test/test_deterministic.py +++ b/model/src/test/test_deterministic.py @@ -1,7 +1,11 @@ # numpydoc ignore=GL08 +import re + import jax.numpy as jnp +import numpy as np import numpy.testing as testing +import pytest from pyrenew.deterministic import ( DeterministicPMF, DeterministicProcess, @@ -62,3 +66,49 @@ def test_deterministic(): testing.assert_equal(var4()[0].value, None) testing.assert_equal(var5(duration=1)[0].value, None) + + +def test_deterministic_validation(): + """ + Check that validation methods for DeterministicVariable + work as expected. + """ + # validation should fail on construction + some_non_array_likes = [ + {"a": jnp.array([1, 2.5, 3])}, + # a valid pytree, but not an arraylike + "a string", + ] + some_array_likes = [ + 5, + -3.023523, + np.array([1, 3.32, 5]), + jnp.array([-32, 23]), + jnp.array(-32), + np.array(5), + ] + + for non_arraylike_val in some_non_array_likes: + matchval = re.escape( + f"value {non_arraylike_val} passed to a " + "DeterministicVariable is of type " + f"{type(non_arraylike_val).__name__}, expected " + "an ArrayLike object" + ) + + with pytest.raises(ValueError, match=matchval): + # the class's validation function itself + # should raise an error when passed a + # non arraylike value + DeterministicVariable.validate(non_arraylike_val) + + with pytest.raises(ValueError, match=matchval): + # validation should fail on constructor call + DeterministicVariable( + value=non_arraylike_val, name="invalid_variable" + ) + + # validation should succeed with ArrayLike + for arraylike_val in some_array_likes: + DeterministicVariable.validate(arraylike_val) + DeterministicVariable(value=arraylike_val, name="valid_variable")