diff --git a/model/src/pyrenew/process/ar.py b/model/src/pyrenew/process/ar.py index 9ce5e941..16a29568 100644 --- a/model/src/pyrenew/process/ar.py +++ b/model/src/pyrenew/process/ar.py @@ -58,7 +58,8 @@ def sample( name : str, optional Name of the parameter passed to numpyro.sample. **kwargs : dict, optional - Ignored. + Additional keyword arguments passed through to internal sample() + calls, should there be any. Returns ------- @@ -85,4 +86,7 @@ def _ar_scanner(carry, next): @staticmethod def validate(): + """ + Validates inputted parameters, implementation pending. + """ return None diff --git a/model/src/pyrenew/process/firstdifferencear.py b/model/src/pyrenew/process/firstdifferencear.py index 7f94b966..84d32edb 100644 --- a/model/src/pyrenew/process/firstdifferencear.py +++ b/model/src/pyrenew/process/firstdifferencear.py @@ -54,7 +54,8 @@ def sample( name : str, optional Passed to ARProcess.sample(), by default "trend_rw" **kwargs : dict, optional - Ignored. + Additional keyword arguments passed through to internal sample() + calls, should there be any. Returns ------- @@ -67,4 +68,7 @@ def sample( @staticmethod def validate(): + """ + Validates inputted parameters, implementation pending. + """ return None diff --git a/model/src/pyrenew/process/rtrandomwalk.py b/model/src/pyrenew/process/rtrandomwalk.py index bff8a460..c6cb32c9 100644 --- a/model/src/pyrenew/process/rtrandomwalk.py +++ b/model/src/pyrenew/process/rtrandomwalk.py @@ -23,7 +23,7 @@ def __init__( Parameters ---------- Rt0_dist : dist.Distribution, optional - Initial distributiono of Rt, defaults to + Initial distribution of Rt, defaults to dist.TruncatedNormal( loc=1.2, scale=0.2, low=0 ) Rt_transform : AbstractTransform, optional Transformation applied to the sampled Rt0, defaults @@ -44,7 +44,35 @@ def __init__( return None @staticmethod - def validate(Rt0_dist, Rt_transform, Rt_rw_dist): + def validate( + Rt0_dist: dist.Distribution, + Rt_transform: AbstractTransform, + Rt_rw_dist: dist.Distribution, + ) -> None: + """ + Validates Rt0_dist, Rt_transform, and Rt_rw_dist. + + Parameters + ---------- + Rt0_dist : dist.Distribution, optional + Initial distribution of Rt, expected dist.Distribution + Rt_transform : any + Transformation applied to the sampled Rt0, expected + AbstractTransform + Rt_rw_dist : any + Randomwalk process, expected dist.Distribution. + + Returns + ------- + None + + Raises + ------ + AssertionError + If Rt0_dist or Rt_rw_dist are not dist.Distribution or if + Rt_transform is not AbstractTransform. + + """ assert isinstance(Rt0_dist, dist.Distribution) assert isinstance(Rt_transform, AbstractTransform) assert isinstance(Rt_rw_dist, dist.Distribution) @@ -61,7 +89,8 @@ def sample( n_timepoints : int Number of timepoints to sample. **kwargs : dict, optional - Ignored. + Additional keyword arguments passed through to internal sample() + calls, should there be any. Returns ------- diff --git a/model/src/pyrenew/process/simplerandomwalk.py b/model/src/pyrenew/process/simplerandomwalk.py index f395212a..59e37973 100644 --- a/model/src/pyrenew/process/simplerandomwalk.py +++ b/model/src/pyrenew/process/simplerandomwalk.py @@ -48,7 +48,8 @@ def sample( init : float, optional Initial point of the walk, by default None **kwargs : dict, optional - Ignored. + Additional keyword arguments passed through to internal sample() + calls, should there be any. Returns ------- @@ -65,4 +66,7 @@ def sample( @staticmethod def validate(): + """ + Validates inputted parameters, implementation pending. + """ return None