diff --git a/docs/source/msei_reference/index.rst b/docs/source/msei_reference/index.rst index f7fae05d..323874ca 100644 --- a/docs/source/msei_reference/index.rst +++ b/docs/source/msei_reference/index.rst @@ -7,6 +7,7 @@ Reference model latent process + randomvariable observation datasets msei diff --git a/docs/source/msei_reference/randomvariable.rst b/docs/source/msei_reference/randomvariable.rst new file mode 100644 index 00000000..3ffe44d0 --- /dev/null +++ b/docs/source/msei_reference/randomvariable.rst @@ -0,0 +1,7 @@ +Random Variables +=========== + +.. automodule:: pyrenew.randomvariable + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index e9a3fcba..bf262094 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -25,11 +25,8 @@ from pyrenew.latent import ( from pyrenew.observation import PoissonObservation from pyrenew.deterministic import DeterministicPMF from pyrenew.model import RtInfectionsRenewalModel -from pyrenew.metaclass import ( - RandomVariable, - DistributionalRV, - TransformedRandomVariable, -) +from pyrenew.metaclass import RandomVariable +from pyrenew.randomvariable import DistributionalVariable, TransformedVariable import pyrenew.transformation as t from numpyro.infer.reparam import LocScaleReparam ``` @@ -64,7 +61,7 @@ flowchart LR subgraph latent[Latent module] inf["latent_infections_rv\n(Infections)"] - i0["I0_rv\n(DistributionalRV)"] + i0["I0_rv\n(DistributionalVariable)"] end subgraph process[Process module] @@ -126,7 +123,7 @@ gen_int = DeterministicPMF(name="gen_int", value=pmf_array) # (2) Initial infections (inferred with a prior) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(2.5, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(2.5, 1)), InitializeInfectionsZeroPad(pmf_array.size), t_unit=1, ) @@ -142,17 +139,17 @@ class MyRt(RandomVariable): def sample(self, n: int, **kwargs) -> tuple: sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025)) - rt_rv = TransformedRandomVariable( + rt_rv = TransformedVariable( name="log_rt_random_walk", base_rv=RandomWalk( name="log_rt", - step_rv=DistributionalRV( + step_rv=DistributionalVariable( name="rw_step_rv", distribution=dist.Normal(0, 0.025) ), ), transforms=t.ExpTransform(), ) - rt_init_rv = DistributionalRV( + rt_init_rv = DistributionalVariable( name="init_log_rt", distribution=dist.Normal(0, 0.2) ) init_rt, *_ = rt_init_rv.sample() diff --git a/docs/source/tutorials/day_of_the_week.qmd b/docs/source/tutorials/day_of_the_week.qmd index 43228731..4b4c6f8f 100644 --- a/docs/source/tutorials/day_of_the_week.qmd +++ b/docs/source/tutorials/day_of_the_week.qmd @@ -51,7 +51,7 @@ inf_hosp_int_array = inf_hosp_int["probability_mass"].to_numpy() ```{python} # | label: latent-hosp # | code-fold: true -from pyrenew import latent, deterministic, metaclass +from pyrenew import latent, deterministic, randomvariable import jax.numpy as jnp import numpyro.distributions as dist @@ -59,7 +59,7 @@ inf_hosp_int = deterministic.DeterministicPMF( name="inf_hosp_int", value=inf_hosp_int_array ) -hosp_rate = metaclass.DistributionalRV( +hosp_rate = randomvariable.DistributionalVariable( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), jnp.log(1.1)) ) @@ -81,7 +81,7 @@ n_initialization_points = max(gen_int_array.size, inf_hosp_int_array.size) - 1 I0 = InfectionInitializationProcess( "I0_initialization", - metaclass.DistributionalRV( + randomvariable.DistributionalVariable( name="I0", distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)), ), @@ -113,11 +113,11 @@ class MyRt(metaclass.RandomVariable): sd_rt, *_ = self.sd_rv() # Random walk step - step_rv = metaclass.DistributionalRV( + step_rv = randomvariable.DistributionalVariable( name="rw_step_rv", distribution=dist.Normal(0, sd_rt.value) ) - rt_init_rv = metaclass.DistributionalRV( + rt_init_rv = randomvariable.DistributionalVariable( name="init_log_rt", distribution=dist.Normal(0, 0.2) ) @@ -128,7 +128,7 @@ class MyRt(metaclass.RandomVariable): ) # Transforming the random walk to the Rt scale - rt_rv = metaclass.TransformedRandomVariable( + rt_rv = randomvariable.TransformedVariable( name="Rt_rv", base_rv=base_rv, transforms=transformation.ExpTransform(), @@ -139,7 +139,7 @@ class MyRt(metaclass.RandomVariable): rtproc = MyRt( - metaclass.DistributionalRV( + randomvariable.DistributionalVariable( name="Rt_random_walk_sd", distribution=dist.HalfNormal(0.025) ) ) @@ -152,9 +152,9 @@ rtproc = MyRt( # | code-fold: true # we place a log-Normal prior on the concentration # parameter of the negative binomial. -nb_conc_rv = metaclass.TransformedRandomVariable( +nb_conc_rv = randomvariable.TransformedVariable( "concentration", - metaclass.DistributionalRV( + randomvariable.DistributionalVariable( name="concentration_raw", distribution=dist.TruncatedNormal(loc=0, scale=1, low=0.01), ), @@ -212,16 +212,16 @@ out = hosp_model.plot_posterior( We will re-use the infection to admission interval and infection to hospitalization rate from the previous model. But we will also add a day-of-the-week effect. To do this, we will add two additional arguments to the latent hospital admissions random variable: `day_of_the_week_rv` (a `RandomVariable`) and `obs_data_first_day_of_the_week` (an `int` mapping days of the week from 0:6, zero being Monday). The `day_of_the_week_rv`'s sample method should return a vector of length seven; those values are then broadcasted to match the length of the dataset. Moreover, since the observed data may start in a weekday other than Monday, the `obs_data_first_day_of_the_week` argument is used to offset the day-of-the-week effect. -For this example, the effect will be passed as a scaled Dirichlet distribution. It will consist of a `TransformedRandomVariable` that samples an array of length seven from numpyro's `distributions.Dirichlet` and applies a `transformation.AffineTransform` to scale it by seven. [^note-other-examples]: +For this example, the effect will be passed as a scaled Dirichlet distribution. It will consist of a `TransformedVariable` that samples an array of length seven from numpyro's `distributions.Dirichlet` and applies a `transformation.AffineTransform` to scale it by seven. [^note-other-examples]: [^note-other-examples]: A similar weekday effect is implemented in its own module, with example code [here](periodic_effects.html). ```{python} # | label: weekly-effect # Instantiating the day-of-the-week effect -dayofweek_effect = metaclass.TransformedRandomVariable( +dayofweek_effect = randomvariable.TransformedVariable( name="dayofweek_effect", - base_rv=metaclass.DistributionalRV( + base_rv=randomvariable.DistributionalVariable( name="dayofweek_effect_raw", distribution=dist.Dirichlet(jnp.ones(7)), ), diff --git a/docs/source/tutorials/extending_pyrenew.qmd b/docs/source/tutorials/extending_pyrenew.qmd index 14615485..f81de653 100644 --- a/docs/source/tutorials/extending_pyrenew.qmd +++ b/docs/source/tutorials/extending_pyrenew.qmd @@ -29,11 +29,8 @@ from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.latent import InfectionsWithFeedback from pyrenew.model import RtInfectionsRenewalModel from pyrenew.process import RandomWalk -from pyrenew.metaclass import ( - RandomVariable, - DistributionalRV, - TransformedRandomVariable, -) +from pyrenew.metaclass import RandomVariable +from pyrenew.randomvariable import DistributionalVariable, TransformedVariable from pyrenew.latent import ( InfectionInitializationProcess, InitializeInfectionsExponentialGrowth, @@ -53,7 +50,7 @@ feedback_strength = DeterministicVariable(name="feedback_strength", value=0.01) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsExponentialGrowth( gen_int_array.size, DeterministicVariable(name="rate", value=0.05), @@ -75,17 +72,17 @@ class MyRt(RandomVariable): def sample(self, n: int, **kwargs) -> tuple: sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025)) - rt_rv = TransformedRandomVariable( + rt_rv = TransformedVariable( name="log_rt_random_walk", base_rv=RandomWalk( name="log_rt", - step_rv=DistributionalRV( + step_rv=DistributionalVariable( name="rw_step_rv", distribution=dist.Normal(0, 0.025) ), ), transforms=t.ExpTransform(), ) - rt_init_rv = DistributionalRV( + rt_init_rv = DistributionalVariable( name="init_log_rt", distribution=dist.Normal(0, 0.2) ) init_rt, *_ = rt_init_rv.sample() diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index c9dca3b0..07c38644 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -137,7 +137,7 @@ With these two in hand, we can start building the model. First, we will define t ```{python} # | label: latent-hosp -from pyrenew import latent, deterministic, metaclass +from pyrenew import latent, deterministic, metaclass, randomvariable import jax.numpy as jnp import numpyro.distributions as dist @@ -145,7 +145,7 @@ inf_hosp_int = deterministic.DeterministicPMF( name="inf_hosp_int", value=inf_hosp_int_array ) -hosp_rate = metaclass.DistributionalRV( +hosp_rate = randomvariable.DistributionalVariable( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), jnp.log(1.1)) ) @@ -155,7 +155,7 @@ latent_hosp = latent.HospitalAdmissions( ) ``` -The `inf_hosp_int` is a `DeterministicPMF` object that takes the infection to hospital admission interval as input. The `hosp_rate` is a `DistributionalRV` object that takes a numpyro distribution to represent the infection to hospital admission rate. The `HospitalAdmissions` class is a `RandomVariable` that takes two distributions as inputs: the infection to admission interval and the infection to hospital admission rate. Now, we can define the rest of the other components: +The `inf_hosp_int` is a `DeterministicPMF` object that takes the infection to hospital admission interval as input. The `hosp_rate` is a `DistributionalVariable` object that takes a numpyro distribution to represent the infection to hospital admission rate. The `HospitalAdmissions` class is a `RandomVariable` that takes two distributions as inputs: the infection to admission interval and the infection to hospital admission rate. Now, we can define the rest of the other components: ```{python} # | label: initializing-rest-of-model @@ -171,7 +171,7 @@ latent_inf = latent.Infections() n_initialization_points = max(gen_int_array.size, inf_hosp_int_array.size) - 1 I0 = InfectionInitializationProcess( "I0_initialization", - metaclass.DistributionalRV( + randomvariable.DistributionalVariable( name="I0", distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)), ), @@ -194,17 +194,17 @@ class MyRt(metaclass.RandomVariable): def sample(self, n: int, **kwargs) -> tuple: sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025)) - rt_rv = metaclass.TransformedRandomVariable( + rt_rv = randomvariable.TransformedVariable( name="log_rt_random_walk", base_rv=process.RandomWalk( name="log_rt", - step_rv=metaclass.DistributionalRV( + step_rv=randomvariable.DistributionalVariable( name="rw_step_rv", distribution=dist.Normal(0, 0.025) ), ), transforms=transformation.ExpTransform(), ) - rt_init_rv = metaclass.DistributionalRV( + rt_init_rv = randomvariable.DistributionalVariable( name="init_log_rt", distribution=dist.Normal(0, 0.2) ) init_rt, *_ = rt_init_rv.sample() @@ -218,9 +218,9 @@ rtproc = MyRt() # we place a log-Normal prior on the concentration # parameter of the negative binomial. -nb_conc_rv = metaclass.TransformedRandomVariable( +nb_conc_rv = randomvariable.TransformedVariable( "concentration", - metaclass.DistributionalRV( + randomvariable.DistributionalVariable( name="concentration_raw", distribution=dist.TruncatedNormal(loc=0, scale=1, low=0.01), ), diff --git a/docs/source/tutorials/periodic_effects.qmd b/docs/source/tutorials/periodic_effects.qmd index bfe3e30d..1603ed59 100644 --- a/docs/source/tutorials/periodic_effects.qmd +++ b/docs/source/tutorials/periodic_effects.qmd @@ -65,7 +65,7 @@ The `PeriodicBroadcaster` class can also be used to repeat a sequence as a whole ```{python} import numpyro.distributions as dist -from pyrenew import transformation, metaclass +from pyrenew import transformation, randomvariable # Building the transformed prior: Dirichlet * 7 mysimplex = dist.TransformedDistribution( @@ -76,7 +76,7 @@ mysimplex = dist.TransformedDistribution( # Constructing the day of week effect dayofweek = process.DayOfWeekEffect( offset=0, - quantity_to_broadcast=metaclass.DistributionalRV( + quantity_to_broadcast=randomvariable.DistributionalVariable( name="simp", distribution=mysimplex ), t_start=0, diff --git a/pyrenew/metaclass.py b/pyrenew/metaclass.py index 424ffead..63d72de1 100644 --- a/pyrenew/metaclass.py +++ b/pyrenew/metaclass.py @@ -5,21 +5,17 @@ """ from abc import ABCMeta, abstractmethod -from typing import Callable, NamedTuple, Self, get_type_hints +from typing import NamedTuple, get_type_hints import jax import jax.random as jr import matplotlib.pyplot as plt import numpy as np -import numpyro -import numpyro.distributions as dist import polars as pl from jax.typing import ArrayLike from numpyro.infer import MCMC, NUTS, Predictive -from numpyro.infer.reparam import Reparam from pyrenew.mcmcutils import plot_posterior, spread_draws -from pyrenew.transformation import Transform def _assert_type(arg_name: str, value, expected_type) -> None: @@ -276,338 +272,6 @@ def __call__(self, **kwargs): return self.sample(**kwargs) -class DynamicDistributionalRV(RandomVariable): - """ - Wrapper class for random variables that sample - from a single :class:`numpyro.distributions.Distribution` - that is parameterized / instantiated at `sample()` time - (rather than at RandomVariable instantiation time). - """ - - def __init__( - self, - name: str, - distribution_constructor: Callable, - reparam: Reparam = None, - expand_by_shape: tuple = None, - ) -> None: - """ - Default constructor for DynamicDistributionalRV. - - Parameters - ---------- - name : str - Name of the random variable. - distribution_constructor : Callable - Callable that returns a concrete parametrized - numpyro.Distributions.distribution instance. - reparam : numpyro.infer.reparam.Reparam - If not None, reparameterize sampling - from the distribution according to the - given numpyro reparameterizer - expand_by_shape : tuple, optional - If not None, call :meth:`expand_by()` on the - underlying distribution once it is instianted - with the given `expand_by_shape`. - Default None. - - Returns - ------- - None - """ - - self.name = name - self.validate(distribution_constructor) - self.distribution_constructor = distribution_constructor - if reparam is not None: - self.reparam_dict = {self.name: reparam} - else: - self.reparam_dict = {} - if not (expand_by_shape is None or isinstance(expand_by_shape, tuple)): - raise ValueError( - "expand_by_shape must be a tuple or be None ", - f"Got {type(expand_by_shape)}", - ) - self.expand_by_shape = expand_by_shape - - return None - - @staticmethod - def validate(distribution_constructor: any) -> None: - """ - Confirm that the distribution_constructor is - callable. - - Parameters - ---------- - distribution_constructor : any - Putative distribution_constructor to validate. - - Returns - ------- - None or raises a ValueError - """ - if not callable(distribution_constructor): - raise ValueError( - "To instantiate a DynamicDistributionalRV, ", - "one must provide a Callable that returns a " - "numpyro.distributions.Distribution as the " - "distribution_constructor argument. " - f"Got {type(distribution_constructor)}, which " - "does not appear to be callable", - ) - return None - - def sample( - self, - *args, - obs: ArrayLike = None, - **kwargs, - ) -> tuple: - """ - Sample from the distributional rv. - - Parameters - ---------- - *args : - Positional arguments passed to self.distribution_constructor - obs : ArrayLike, optional - Observations passed as the `obs` argument to - :meth:`numpyro.sample()`. Default `None`. - **kwargs : dict, optional - Keyword arguments passed to self.distribution_constructor - - Returns - ------- - SampledValue - Containing a sample from the distribution. - """ - distribution = self.distribution_constructor(*args, **kwargs) - if self.expand_by_shape is not None: - distribution = distribution.expand_by(self.expand_by_shape) - with numpyro.handlers.reparam(config=self.reparam_dict): - sample = numpyro.sample( - name=self.name, - fn=distribution, - obs=obs, - ) - return ( - SampledValue( - sample, - t_start=self.t_start, - t_unit=self.t_unit, - ), - ) - - def expand_by(self, sample_shape) -> Self: - """ - Expand the distribution by a given - shape_shape, if possible. Returns a - new DynamicDistributionalRV whose underlying - distribution will be expanded by the given shape - at sample() time. - - Parameters - ---------- - sample_shape : tuple - Sample shape by which to expand the distribution. - Passed to the expand_by() method of - :class:`numpyro.distributions.Distribution` - after the distribution is instantiated. - - Returns - ------- - DynamicDistributionalRV - Whose underlying distribution will be expanded by - the given sample shape at sampling time. - """ - return DynamicDistributionalRV( - name=self.name, - distribution_constructor=self.distribution_constructor, - reparam=self.reparam_dict.get(self.name, None), - expand_by_shape=sample_shape, - ) - - -class StaticDistributionalRV(RandomVariable): - """ - Wrapper class for random variables that sample - from a single :class:`numpyro.distributions.Distribution` - that is parameterized / instantiated at RandomVariable - instantiation time (rather than at `sample()`-ing time). - """ - - def __init__( - self, - name: str, - distribution: numpyro.distributions.Distribution, - reparam: Reparam = None, - ) -> None: - """ - Default constructor for DistributionalRV. - - Parameters - ---------- - name : str - Name of the random variable. - distribution : numpyro.distributions.Distribution - Distribution of the random variable. - reparam : numpyro.infer.reparam.Reparam - If not None, reparameterize sampling - from the distribution according to the - given numpyro reparameterizer - - Returns - ------- - None - """ - - self.name = name - self.validate(distribution) - self.distribution = distribution - if reparam is not None: - self.reparam_dict = {self.name: reparam} - else: - self.reparam_dict = {} - - return None - - @staticmethod - def validate(distribution: any) -> None: - """ - Validation of the distribution. - """ - if not isinstance(distribution, numpyro.distributions.Distribution): - raise ValueError( - "distribution should be an instance of " - "numpyro.distributions.Distribution, got " - "{type(distribution)}" - ) - - return None - - def sample( - self, - obs: ArrayLike | None = None, - **kwargs, - ) -> tuple: - """ - Sample from the distribution. - - Parameters - ---------- - obs : ArrayLike, optional - Observations passed as the `obs` argument to - :meth:`numpyro.sample()`. Default `None`. - **kwargs : dict, optional - Additional keyword arguments passed through - to internal sample calls, should there be any. - - Returns - ------- - SampledValue - Containing a sample from the distribution. - """ - with numpyro.handlers.reparam(config=self.reparam_dict): - sample = numpyro.sample( - name=self.name, - fn=self.distribution, - obs=obs, - ) - return ( - SampledValue( - sample, - t_start=self.t_start, - t_unit=self.t_unit, - ), - ) - - def expand_by(self, sample_shape) -> Self: - """ - Expand the distribution by the given sample_shape, - if possible. Returns a new StaticDistributionalRV - whose underlying distribution has been expanded by - the given sample_shape via - :meth:`~numpyro.distributions.Distribution.expand_by()` - - Parameters - ---------- - sample_shape : tuple - Sample shape for the expansion. Passed to the - :meth:`expand_by()` method of - :class:`numpyro.distributions.Distribution`. - - Returns - ------- - StaticDistributionalRV - Whose underlying distribution has been expanded by - the given sample shape. - """ - if not isinstance(sample_shape, tuple): - raise ValueError( - "sample_shape for expand()-ing " - "a DistributionalRV must be a " - f"tuple. Got {type(sample_shape)}" - ) - return StaticDistributionalRV( - name=self.name, - distribution=self.distribution.expand_by(sample_shape), - reparam=self.reparam_dict.get(self.name, None), - ) - - -def DistributionalRV( - name: str, - distribution: numpyro.distributions.Distribution | Callable, - reparam: Reparam = None, -) -> RandomVariable: - """ - Factory function to generate Distributional RandomVariables, - either static or dynamic. - - Parameters - ---------- - name : str - Name of the random variable. - - distribution: numpyro.distributions.Distribution | Callable - Either numpyro.distributions.Distribution instance - given the static distribution of the random variable or - a callable that returns a parameterized - numpyro.distributions.Distribution when called, which - allows for dynamically-parameterized DistributionalRVs, - e.g. a Normal distribution with an inferred location and - scale. - - reparam : numpyro.infer.reparam.Reparam - If not None, reparameterize sampling - from the distribution according to the - given numpyro reparameterizer - - Returns - ------- - DynamicDistributionalRV | StaticDistributionalRV or - raises a ValueError if a distribution cannot be constructed. - """ - if isinstance(distribution, dist.Distribution): - return StaticDistributionalRV( - name=name, distribution=distribution, reparam=reparam - ) - elif callable(distribution): - return DynamicDistributionalRV( - name=name, distribution_constructor=distribution, reparam=reparam - ) - else: - raise ValueError( - "distribution argument to DistributionalRV " - "must be either a numpyro.distributions.Distribution " - "(for instantiating a static DistributionalRV) " - "or a callable that returns a " - "numpyro.distributions.Distribution (for " - "a dynamic DistributionalRV" - ) - - class Model(metaclass=ABCMeta): """Abstract base class for models""" @@ -891,137 +555,3 @@ def prior_predictive( ) return predictive(rng_key, **kwargs) - - -class TransformedRandomVariable(RandomVariable): - """ - Class to represent RandomVariables defined - by taking the output of another RV's - :meth:`RandomVariable.sample()` method - and transforming it by a given transformation - (typically a :class:`Transform`) - """ - - def __init__( - self, - name: str, - base_rv: RandomVariable, - transforms: Transform | tuple[Transform], - ): - """ - Default constructor - - Parameters - ---------- - name : str - A name for the random variable instance. - base_rv : RandomVariable - The underlying (untransformed) RandomVariable. - transforms : Transform - Transformation or tuple of transformations - to apply to the output of - `base_rv.sample()`; single values will be coerced to - a length-one tuple. If a tuple, should be the same - length as the tuple returned by `base_rv.sample()`. - - Returns - ------- - None - """ - self.name = name - self.base_rv = base_rv - - if not isinstance(transforms, tuple): - transforms = (transforms,) - self.transforms = transforms - self.validate() - - def sample(self, record=False, **kwargs) -> tuple: - """ - Sample method. Call self.base_rv.sample() - and then apply the transforms specified - in self.transforms. - - Parameters - ---------- - record : bool, optional - Whether to record the value of the deterministic - RandomVariable. Defaults to False. - **kwargs : - Keyword arguments passed to self.base_rv.sample() - - Returns - ------- - tuple of the same length as the tuple returned by - self.base_rv.sample() - """ - - untransformed_values = self.base_rv.sample(**kwargs) - transformed_values = tuple( - SampledValue( - t(uv.value), - t_start=self.t_start, - t_unit=self.t_unit, - ) - for t, uv in zip(self.transforms, untransformed_values) - ) - - if record: - if len(untransformed_values) == 1: - numpyro.deterministic(self.name, transformed_values[0].value) - else: - suffixes = ( - untransformed_values._fields - if hasattr(untransformed_values, "_fields") - else range(len(transformed_values)) - ) - for suffix, tv in zip(suffixes, transformed_values): - numpyro.deterministic(f"{self.name}_{suffix}", tv.value) - - return transformed_values - - def sample_length(self): - """ - Sample length for a transformed - random variable must be equal to the - length of self.transforms or - validation will fail. - - Returns - ------- - int - Equal to the length self.transforms - """ - return len(self.transforms) - - def validate(self): - """ - Perform validation checks on a - TransformedRandomVariable instance, - confirming that all transformations - are callable and that the number of - transformations is equal to the sample - length of the base random variable. - - Returns - ------- - None - on successful validation, or raise a ValueError - """ - for t in self.transforms: - if not callable(t): - raise ValueError( - "All entries in self.transforms " "must be callable" - ) - if hasattr(self.base_rv, "sample_length"): - n_transforms = len(self.transforms) - n_entries = self.base_rv.sample_length() - if not n_transforms == n_entries: - raise ValueError( - "There must be exactly as many transformations " - "specified as entries self.transforms as there are " - "entries in the tuple returned by " - "self.base_rv.sample()." - f"Got {n_transforms} transforms and {n_entries} " - "entries" - ) diff --git a/pyrenew/process/iidrandomsequence.py b/pyrenew/process/iidrandomsequence.py index 2f868ada..10adfa9c 100644 --- a/pyrenew/process/iidrandomsequence.py +++ b/pyrenew/process/iidrandomsequence.py @@ -4,7 +4,8 @@ import numpyro.distributions as dist from numpyro.contrib.control_flow import scan -from pyrenew.metaclass import DistributionalRV, RandomVariable, SampledValue +from pyrenew.metaclass import RandomVariable, SampledValue +from pyrenew.randomvariable import DistributionalVariable class IIDRandomSequence(RandomVariable): @@ -130,7 +131,7 @@ def __init__( see :class:`IIDRandomSequence`. element_rv_name: str Name for the internal element_rv, here a - DistributionalRV encoding a + DistributionalVariable encoding a standard Normal (mean = 0, sd = 1) distribution. @@ -139,7 +140,7 @@ def __init__( None """ super().__init__( - element_rv=DistributionalRV( + element_rv=DistributionalVariable( name=element_rv_name, distribution=dist.Normal(0, 1) ), ) diff --git a/pyrenew/process/randomwalk.py b/pyrenew/process/randomwalk.py index a9fa472e..6b0a763d 100644 --- a/pyrenew/process/randomwalk.py +++ b/pyrenew/process/randomwalk.py @@ -3,9 +3,10 @@ import numpyro.distributions as dist -from pyrenew.metaclass import DistributionalRV, RandomVariable +from pyrenew.metaclass import RandomVariable from pyrenew.process.differencedprocess import DifferencedProcess from pyrenew.process.iidrandomsequence import IIDRandomSequence +from pyrenew.randomvariable import DistributionalVariable class RandomWalk(DifferencedProcess): @@ -69,7 +70,7 @@ def __init__( Parameters ---------- step_rv_name : - Name for the DistributionalRV + Name for the DistributionalVariable from which the Normal(0, 1) steps are sampled. **kwargs: @@ -80,7 +81,7 @@ def __init__( None """ super().__init__( - step_rv=DistributionalRV( + step_rv=DistributionalVariable( name=step_rv_name, distribution=dist.Normal(0.0, 1.0) ), **kwargs, diff --git a/pyrenew/randomvariable/__init__.py b/pyrenew/randomvariable/__init__.py new file mode 100644 index 00000000..4f154b2d --- /dev/null +++ b/pyrenew/randomvariable/__init__.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- + +# numpydoc ignore=GL08 + +from pyrenew.randomvariable.distributionalvariable import ( + DistributionalVariable, + DynamicDistributionalVariable, + StaticDistributionalVariable, +) +from pyrenew.randomvariable.transformedvariable import TransformedVariable + +__all__ = [ + "DistributionalVariable", + "StaticDistributionalVariable", + "DynamicDistributionalVariable", + "TransformedVariable", +] diff --git a/pyrenew/randomvariable/distributionalvariable.py b/pyrenew/randomvariable/distributionalvariable.py new file mode 100644 index 00000000..671dde08 --- /dev/null +++ b/pyrenew/randomvariable/distributionalvariable.py @@ -0,0 +1,342 @@ +# numpydoc ignore=GL08 + +from typing import Callable, Self + +import numpyro +import numpyro.distributions as dist +from jax.typing import ArrayLike +from numpyro.infer.reparam import Reparam + +from pyrenew.metaclass import RandomVariable, SampledValue + + +class DynamicDistributionalVariable(RandomVariable): + """ + Wrapper class for random variables that sample + from a single :class:`numpyro.distributions.Distribution` + that is parameterized / instantiated at `sample()` time + (rather than at RandomVariable instantiation time). + """ + + def __init__( + self, + name: str, + distribution_constructor: Callable, + reparam: Reparam = None, + expand_by_shape: tuple = None, + ) -> None: + """ + Default constructor for DynamicDistributionalVariable. + + Parameters + ---------- + name : str + Name of the random variable. + distribution_constructor : Callable + Callable that returns a concrete parametrized + numpyro.Distributions.distribution instance. + reparam : numpyro.infer.reparam.Reparam + If not None, reparameterize sampling + from the distribution according to the + given numpyro reparameterizer + expand_by_shape : tuple, optional + If not None, call :meth:`expand_by()` on the + underlying distribution once it is instianted + with the given `expand_by_shape`. + Default None. + + Returns + ------- + None + """ + + self.name = name + self.validate(distribution_constructor) + self.distribution_constructor = distribution_constructor + if reparam is not None: + self.reparam_dict = {self.name: reparam} + else: + self.reparam_dict = {} + if not (expand_by_shape is None or isinstance(expand_by_shape, tuple)): + raise ValueError( + "expand_by_shape must be a tuple or be None ", + f"Got {type(expand_by_shape)}", + ) + self.expand_by_shape = expand_by_shape + + return None + + @staticmethod + def validate(distribution_constructor: any) -> None: + """ + Confirm that the distribution_constructor is + callable. + + Parameters + ---------- + distribution_constructor : any + Putative distribution_constructor to validate. + + Returns + ------- + None or raises a ValueError + """ + if not callable(distribution_constructor): + raise ValueError( + "To instantiate a DynamicDistributionalVariable, ", + "one must provide a Callable that returns a " + "numpyro.distributions.Distribution as the " + "distribution_constructor argument. " + f"Got {type(distribution_constructor)}, which " + "does not appear to be callable", + ) + return None + + def sample( + self, + *args, + obs: ArrayLike = None, + **kwargs, + ) -> tuple: + """ + Sample from the distributional rv. + + Parameters + ---------- + *args : + Positional arguments passed to self.distribution_constructor + obs : ArrayLike, optional + Observations passed as the `obs` argument to + :meth:`numpyro.sample()`. Default `None`. + **kwargs : dict, optional + Keyword arguments passed to self.distribution_constructor + + Returns + ------- + SampledValue + Containing a sample from the distribution. + """ + distribution = self.distribution_constructor(*args, **kwargs) + if self.expand_by_shape is not None: + distribution = distribution.expand_by(self.expand_by_shape) + with numpyro.handlers.reparam(config=self.reparam_dict): + sample = numpyro.sample( + name=self.name, + fn=distribution, + obs=obs, + ) + return ( + SampledValue( + sample, + t_start=self.t_start, + t_unit=self.t_unit, + ), + ) + + def expand_by(self, sample_shape) -> Self: + """ + Expand the distribution by a given + shape_shape, if possible. Returns a + new DynamicDistributionalVariable whose underlying + distribution will be expanded by the given shape + at sample() time. + + Parameters + ---------- + sample_shape : tuple + Sample shape by which to expand the distribution. + Passed to the expand_by() method of + :class:`numpyro.distributions.Distribution` + after the distribution is instantiated. + + Returns + ------- + DynamicDistributionalVariable + Whose underlying distribution will be expanded by + the given sample shape at sampling time. + """ + return DynamicDistributionalVariable( + name=self.name, + distribution_constructor=self.distribution_constructor, + reparam=self.reparam_dict.get(self.name, None), + expand_by_shape=sample_shape, + ) + + +class StaticDistributionalVariable(RandomVariable): + """ + Wrapper class for random variables that sample + from a single :class:`numpyro.distributions.Distribution` + that is parameterized / instantiated at RandomVariable + instantiation time (rather than at `sample()`-ing time). + """ + + def __init__( + self, + name: str, + distribution: numpyro.distributions.Distribution, + reparam: Reparam = None, + ) -> None: + """ + Default constructor for DistributionalVariable. + + Parameters + ---------- + name : str + Name of the random variable. + distribution : numpyro.distributions.Distribution + Distribution of the random variable. + reparam : numpyro.infer.reparam.Reparam + If not None, reparameterize sampling + from the distribution according to the + given numpyro reparameterizer + + Returns + ------- + None + """ + + self.name = name + self.validate(distribution) + self.distribution = distribution + if reparam is not None: + self.reparam_dict = {self.name: reparam} + else: + self.reparam_dict = {} + + return None + + @staticmethod + def validate(distribution: any) -> None: + """ + Validation of the distribution. + """ + if not isinstance(distribution, numpyro.distributions.Distribution): + raise ValueError( + "distribution should be an instance of " + "numpyro.distributions.Distribution, got " + "{type(distribution)}" + ) + + return None + + def sample( + self, + obs: ArrayLike | None = None, + **kwargs, + ) -> tuple: + """ + Sample from the distribution. + + Parameters + ---------- + obs : ArrayLike, optional + Observations passed as the `obs` argument to + :meth:`numpyro.sample()`. Default `None`. + **kwargs : dict, optional + Additional keyword arguments passed through + to internal sample calls, should there be any. + + Returns + ------- + SampledValue + Containing a sample from the distribution. + """ + with numpyro.handlers.reparam(config=self.reparam_dict): + sample = numpyro.sample( + name=self.name, + fn=self.distribution, + obs=obs, + ) + return ( + SampledValue( + sample, + t_start=self.t_start, + t_unit=self.t_unit, + ), + ) + + def expand_by(self, sample_shape) -> Self: + """ + Expand the distribution by the given sample_shape, + if possible. Returns a new StaticDistributionalVariable + whose underlying distribution has been expanded by + the given sample_shape via + :meth:`~numpyro.distributions.Distribution.expand_by()` + + Parameters + ---------- + sample_shape : tuple + Sample shape for the expansion. Passed to the + :meth:`expand_by()` method of + :class:`numpyro.distributions.Distribution`. + + Returns + ------- + StaticDistributionalVariable + Whose underlying distribution has been expanded by + the given sample shape. + """ + if not isinstance(sample_shape, tuple): + raise ValueError( + "sample_shape for expand()-ing " + "a DistributionalVariable must be a " + f"tuple. Got {type(sample_shape)}" + ) + return StaticDistributionalVariable( + name=self.name, + distribution=self.distribution.expand_by(sample_shape), + reparam=self.reparam_dict.get(self.name, None), + ) + + +def DistributionalVariable( + name: str, + distribution: numpyro.distributions.Distribution | Callable, + reparam: Reparam = None, +) -> RandomVariable: + """ + Factory function to generate Distributional RandomVariables, + either static or dynamic. + + Parameters + ---------- + name : str + Name of the random variable. + + distribution: numpyro.distributions.Distribution | Callable + Either numpyro.distributions.Distribution instance + given the static distribution of the random variable or + a callable that returns a parameterized + numpyro.distributions.Distribution when called, which + allows for dynamically-parameterized DistributionalVariables, + e.g. a Normal distribution with an inferred location and + scale. + + reparam : numpyro.infer.reparam.Reparam + If not None, reparameterize sampling + from the distribution according to the + given numpyro reparameterizer + + Returns + ------- + DynamicDistributionalVariable | StaticDistributionalVariable or + raises a ValueError if a distribution cannot be constructed. + """ + if isinstance(distribution, dist.Distribution): + return StaticDistributionalVariable( + name=name, distribution=distribution, reparam=reparam + ) + elif callable(distribution): + return DynamicDistributionalVariable( + name=name, distribution_constructor=distribution, reparam=reparam + ) + else: + raise ValueError( + "distribution argument to DistributionalVariable " + "must be either a numpyro.distributions.Distribution " + "(for instantiating a static DistributionalVariable) " + "or a callable that returns a " + "numpyro.distributions.Distribution (for " + "a dynamic DistributionalVariable" + ) diff --git a/pyrenew/randomvariable/transformedvariable.py b/pyrenew/randomvariable/transformedvariable.py new file mode 100644 index 00000000..36519a24 --- /dev/null +++ b/pyrenew/randomvariable/transformedvariable.py @@ -0,0 +1,140 @@ +# numpydoc ignore=GL08 + +import numpyro + +from pyrenew.metaclass import RandomVariable, SampledValue +from pyrenew.transformation import Transform + + +class TransformedVariable(RandomVariable): + """ + Class to represent RandomVariables defined + by taking the output of another RV's + :meth:`RandomVariable.sample()` method + and transforming it by a given transformation + (typically a :class:`Transform`) + """ + + def __init__( + self, + name: str, + base_rv: RandomVariable, + transforms: Transform | tuple[Transform], + ): + """ + Default constructor + + Parameters + ---------- + name : str + A name for the random variable instance. + base_rv : RandomVariable + The underlying (untransformed) RandomVariable. + transforms : Transform + Transformation or tuple of transformations + to apply to the output of + `base_rv.sample()`; single values will be coerced to + a length-one tuple. If a tuple, should be the same + length as the tuple returned by `base_rv.sample()`. + + Returns + ------- + None + """ + self.name = name + self.base_rv = base_rv + + if not isinstance(transforms, tuple): + transforms = (transforms,) + self.transforms = transforms + self.validate() + + def sample(self, record=False, **kwargs) -> tuple: + """ + Sample method. Call self.base_rv.sample() + and then apply the transforms specified + in self.transforms. + + Parameters + ---------- + record : bool, optional + Whether to record the value of the deterministic + RandomVariable. Defaults to False. + **kwargs : + Keyword arguments passed to self.base_rv.sample() + + Returns + ------- + tuple of the same length as the tuple returned by + self.base_rv.sample() + """ + + untransformed_values = self.base_rv.sample(**kwargs) + transformed_values = tuple( + SampledValue( + t(uv.value), + t_start=self.t_start, + t_unit=self.t_unit, + ) + for t, uv in zip(self.transforms, untransformed_values) + ) + + if record: + if len(untransformed_values) == 1: + numpyro.deterministic(self.name, transformed_values[0].value) + else: + suffixes = ( + untransformed_values._fields + if hasattr(untransformed_values, "_fields") + else range(len(transformed_values)) + ) + for suffix, tv in zip(suffixes, transformed_values): + numpyro.deterministic(f"{self.name}_{suffix}", tv.value) + + return transformed_values + + def sample_length(self): + """ + Sample length for a transformed + random variable must be equal to the + length of self.transforms or + validation will fail. + + Returns + ------- + int + Equal to the length self.transforms + """ + return len(self.transforms) + + def validate(self): + """ + Perform validation checks on a + TransformedVariable instance, + confirming that all transformations + are callable and that the number of + transformations is equal to the sample + length of the base random variable. + + Returns + ------- + None + on successful validation, or raise a ValueError + """ + for t in self.transforms: + if not callable(t): + raise ValueError( + "All entries in self.transforms " "must be callable" + ) + if hasattr(self.base_rv, "sample_length"): + n_transforms = len(self.transforms) + n_entries = self.base_rv.sample_length() + if not n_transforms == n_entries: + raise ValueError( + "There must be exactly as many transformations " + "specified as entries self.transforms as there are " + "entries in the tuple returned by " + "self.base_rv.sample()." + f"Got {n_transforms} transforms and {n_entries} " + "entries" + ) diff --git a/test/test_assert_sample_and_rtype.py b/test/test_assert_sample_and_rtype.py index 69a59f0f..d0f9ee8a 100644 --- a/test/test_assert_sample_and_rtype.py +++ b/test/test_assert_sample_and_rtype.py @@ -9,11 +9,11 @@ from pyrenew.deterministic import DeterministicVariable, NullObservation from pyrenew.metaclass import ( - DistributionalRV, RandomVariable, SampledValue, _assert_sample_and_rtype, ) +from pyrenew.randomvariable import DistributionalVariable class RVreturnsTuple(RandomVariable): @@ -93,7 +93,7 @@ def test_input_rv(): # numpydoc ignore=GL08 valid_rv = [ NullObservation(), DeterministicVariable(name="rv1", value=jnp.array([1, 2, 3, 4])), - DistributionalRV(name="rv2", distribution=dist.Normal(0, 1)), + DistributionalVariable(name="rv2", distribution=dist.Normal(0, 1)), ] not_rv = jnp.array([1]) diff --git a/test/test_assert_type.py b/test/test_assert_type.py index 7a41cdc8..a885cef3 100644 --- a/test/test_assert_type.py +++ b/test/test_assert_type.py @@ -3,7 +3,8 @@ import numpyro.distributions as dist import pytest -from pyrenew.metaclass import DistributionalRV, RandomVariable, _assert_type +from pyrenew.metaclass import RandomVariable, _assert_type +from pyrenew.randomvariable import DistributionalVariable def test_valid_assertion_types(): @@ -15,7 +16,7 @@ def test_valid_assertion_types(): 5, "Hello", (1,), - DistributionalRV(name="rv", distribution=dist.Beta(1, 1)), + DistributionalVariable(name="rv", distribution=dist.Beta(1, 1)), ] arg_names = ["input_int", "input_string", "input_tuple", "input_rv"] input_types = [int, str, tuple, RandomVariable] diff --git a/test/test_differenced_process.py b/test/test_differenced_process.py index 63c28073..ba4e95c9 100644 --- a/test/test_differenced_process.py +++ b/test/test_differenced_process.py @@ -10,12 +10,12 @@ from numpy.testing import assert_array_almost_equal from pyrenew.deterministic import DeterministicVariable, NullVariable -from pyrenew.metaclass import DistributionalRV from pyrenew.process import ( DifferencedProcess, IIDRandomSequence, StandardNormalSequence, ) +from pyrenew.randomvariable import DistributionalVariable @pytest.mark.parametrize( @@ -155,7 +155,7 @@ def test_manual_integrator_correctness(diffs, inits, expected_solution): [ [ IIDRandomSequence( - DistributionalRV("element_dist", dist.Cauchy(0.02, 0.3)), + DistributionalVariable("element_dist", dist.Cauchy(0.02, 0.3)), ), 3, jnp.array([0.25, 0.67, 5]), diff --git a/test/test_distributional_rv.py b/test/test_distributional_rv.py index 0a0b4d2c..cebe6f8e 100644 --- a/test/test_distributional_rv.py +++ b/test/test_distributional_rv.py @@ -1,6 +1,7 @@ """ Tests for the distributional RV classes """ + import jax.numpy as jnp import numpyro import numpyro.distributions as dist @@ -8,17 +9,17 @@ from numpy.testing import assert_array_equal from numpyro.distributions import ExpandedDistribution -from pyrenew.metaclass import ( - DistributionalRV, - DynamicDistributionalRV, - StaticDistributionalRV, +from pyrenew.randomvariable import ( + DistributionalVariable, + DynamicDistributionalVariable, + StaticDistributionalVariable, ) class NonCallableTestClass: """ Generic non-callable object to test - callable checking for DynamicDistributionalRV. + callable checking for DynamicDistributionalVariable. """ def __init__(self): @@ -37,9 +38,11 @@ def test_invalid_constructor_args(not_a_dist): """ with pytest.raises( - ValueError, match="distribution argument to DistributionalRV" + ValueError, match="distribution argument to DistributionalVariable" ): - DistributionalRV(name="this should fail", distribution=not_a_dist) + DistributionalVariable( + name="this should fail", distribution=not_a_dist + ) with pytest.raises( ValueError, match=( @@ -47,9 +50,9 @@ def test_invalid_constructor_args(not_a_dist): "numpyro.distributions.Distribution" ), ): - StaticDistributionalRV.validate(not_a_dist) + StaticDistributionalVariable.validate(not_a_dist) with pytest.raises(ValueError, match="must provide a Callable"): - DynamicDistributionalRV.validate(not_a_dist) + DynamicDistributionalVariable.validate(not_a_dist) @pytest.mark.parametrize( @@ -63,18 +66,18 @@ def test_invalid_constructor_args(not_a_dist): def test_factory_triage(valid_static_dist_arg, valid_dynamic_dist_arg): """ Test that passing a numpyro.distributions.Distribution - instance to the DistributionalRV factory instaniates - a StaticDistributionalRV, while passing a callable - instaniates a DynamicDistributionalRV + instance to the DistributionalVariable factory instaniates + a StaticDistributionalVariable, while passing a callable + instaniates a DynamicDistributionalVariable """ - static = DistributionalRV( + static = DistributionalVariable( name="test static", distribution=valid_static_dist_arg ) - assert isinstance(static, StaticDistributionalRV) - dynamic = DistributionalRV( + assert isinstance(static, StaticDistributionalVariable) + dynamic = DistributionalVariable( name="test dynamic", distribution=valid_dynamic_dist_arg ) - assert isinstance(dynamic, DynamicDistributionalRV) + assert isinstance(dynamic, DynamicDistributionalVariable) @pytest.mark.parametrize( @@ -97,12 +100,12 @@ def test_expand_by(dist, params, expand_by_shape): Test the expand_by method for static distributional RVs. """ - static = DistributionalRV(name="static", distribution=dist(**params)) - dynamic = DistributionalRV(name="dynamic", distribution=dist) + static = DistributionalVariable(name="static", distribution=dist(**params)) + dynamic = DistributionalVariable(name="dynamic", distribution=dist) expanded_static = static.expand_by(expand_by_shape) expanded_dynamic = dynamic.expand_by(expand_by_shape) - assert isinstance(expanded_dynamic, DynamicDistributionalRV) + assert isinstance(expanded_dynamic, DynamicDistributionalVariable) assert dynamic.expand_by_shape is None assert isinstance(expanded_dynamic.expand_by_shape, tuple) assert expanded_dynamic.expand_by_shape == expand_by_shape @@ -112,7 +115,7 @@ def test_expand_by(dist, params, expand_by_shape): == expanded_dynamic.distribution_constructor ) - assert isinstance(expanded_static, StaticDistributionalRV) + assert isinstance(expanded_static, StaticDistributionalVariable) assert isinstance(expanded_static.distribution, ExpandedDistribution) assert expanded_static.distribution.batch_shape == ( expand_by_shape + static.distribution.batch_shape @@ -140,15 +143,15 @@ def test_expand_by(dist, params, expand_by_shape): ) def test_sampling_equivalent(dist, params): """ - Test that sampling a DynamicDistributionalRV + Test that sampling a DynamicDistributionalVariable with a given parameterization is equivalent to - sampling a StaticDistributionalRV with the + sampling a StaticDistributionalVariable with the same parameterization and the same random seed """ - static = DistributionalRV(name="static", distribution=dist(**params)) - dynamic = DistributionalRV(name="dynamic", distribution=dist) - assert isinstance(static, StaticDistributionalRV) - assert isinstance(dynamic, DynamicDistributionalRV) + static = DistributionalVariable(name="static", distribution=dist(**params)) + dynamic = DistributionalVariable(name="dynamic", distribution=dist) + assert isinstance(static, StaticDistributionalVariable) + assert isinstance(dynamic, DynamicDistributionalVariable) with numpyro.handlers.seed(rng_seed=5): static_samp, *_ = static() with numpyro.handlers.seed(rng_seed=5): diff --git a/test/test_forecast.py b/test/test_forecast.py index beef0273..d8d1d55c 100644 --- a/test/test_forecast.py +++ b/test/test_forecast.py @@ -14,9 +14,9 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation +from pyrenew.randomvariable import DistributionalVariable def test_forecast(): @@ -28,7 +28,7 @@ def test_forecast(): gen_int = DeterministicPMF(name="gen_int", value=pmf_array) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) diff --git a/test/test_iid_random_sequence.py b/test/test_iid_random_sequence.py index eb6d943c..73b683aa 100755 --- a/test/test_iid_random_sequence.py +++ b/test/test_iid_random_sequence.py @@ -6,12 +6,12 @@ import pytest from scipy.stats import kstest -from pyrenew.metaclass import ( - DistributionalRV, - SampledValue, - StaticDistributionalRV, -) +from pyrenew.metaclass import SampledValue from pyrenew.process import IIDRandomSequence, StandardNormalSequence +from pyrenew.randomvariable import ( + DistributionalVariable, + StaticDistributionalVariable, +) @pytest.mark.parametrize( @@ -29,7 +29,7 @@ def test_iidrandomsequence_with_dist_rv(distribution, n): a distributional RV, including with array-valued distributions """ - element_rv = DistributionalRV("el_rv", distribution=distribution) + element_rv = DistributionalVariable("el_rv", distribution=distribution) rseq = IIDRandomSequence(element_rv=element_rv) if distribution.batch_shape == () or distribution.batch_shape == (1,): expected_shape = (n,) @@ -63,9 +63,9 @@ def test_standard_normal_sequence(): """ norm_seq = StandardNormalSequence("test_norm_elements") - # should be implemented with a DistributionalRV + # should be implemented with a DistributionalVariable # that is a standard normal - assert isinstance(norm_seq.element_rv, StaticDistributionalRV) + assert isinstance(norm_seq.element_rv, StaticDistributionalVariable) assert isinstance(norm_seq.element_rv.distribution, dist.Normal) assert norm_seq.element_rv.distribution.loc == 0.0 assert norm_seq.element_rv.distribution.scale == 1.0 diff --git a/test/test_infection_initialization_process.py b/test/test_infection_initialization_process.py index afe91ef6..069299cd 100644 --- a/test/test_infection_initialization_process.py +++ b/test/test_infection_initialization_process.py @@ -11,7 +11,7 @@ InitializeInfectionsFromVec, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV +from pyrenew.randomvariable import DistributionalVariable def test_infection_initialization_process(): @@ -20,14 +20,14 @@ def test_infection_initialization_process(): zero_pad_model = InfectionInitializationProcess( "zero_pad_model", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints), t_unit=1, ) exp_model = InfectionInitializationProcess( "exp_model", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsExponentialGrowth( n_timepoints, DeterministicVariable(name="rate", value=0.5) ), diff --git a/test/test_latent_admissions.py b/test/test_latent_admissions.py index 526fbc31..1e82db89 100644 --- a/test/test_latent_admissions.py +++ b/test/test_latent_admissions.py @@ -10,7 +10,8 @@ from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.latent import HospitalAdmissions, Infections -from pyrenew.metaclass import DistributionalRV, SampledValue +from pyrenew.metaclass import SampledValue +from pyrenew.randomvariable import DistributionalVariable def test_admissions_sample(): @@ -64,7 +65,7 @@ def test_admissions_sample(): hosp1 = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, - infection_hospitalization_ratio_rv=DistributionalRV( + infection_hospitalization_ratio_rv=DistributionalVariable( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05) ), ) diff --git a/test/test_model_basic_renewal.py b/test/test_model_basic_renewal.py index ffe09cd4..1b0314f8 100644 --- a/test/test_model_basic_renewal.py +++ b/test/test_model_basic_renewal.py @@ -18,9 +18,9 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation +from pyrenew.randomvariable import DistributionalVariable def test_model_basicrenewal_no_timepoints_or_observations(): @@ -36,7 +36,7 @@ def test_model_basicrenewal_no_timepoints_or_observations(): I0_init_rv = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -72,7 +72,7 @@ def test_model_basicrenewal_both_timepoints_and_observations(): I0_init_rv = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -111,11 +111,11 @@ def test_model_basicrenewal_no_obs_model(): ) with pytest.raises(ValueError): - _ = DistributionalRV(name="I0", distribution=1) + _ = DistributionalVariable(name="I0", distribution=1) I0_init_rv = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -186,7 +186,7 @@ def test_model_basicrenewal_with_obs_model(): I0_init_rv = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -240,7 +240,7 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08 I0_init_rv = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) diff --git a/test/test_model_hosp_admissions.py b/test/test_model_hosp_admissions.py index bb740944..f6d3d3a2 100644 --- a/test/test_model_hosp_admissions.py +++ b/test/test_model_hosp_admissions.py @@ -23,9 +23,10 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV, RandomVariable, SampledValue +from pyrenew.metaclass import RandomVariable, SampledValue from pyrenew.model import HospitalAdmissionsModel from pyrenew.observation import PoissonObservation +from pyrenew.randomvariable import DistributionalVariable class UniformProbForTest(RandomVariable): # numpydoc ignore=GL08 @@ -91,7 +92,7 @@ def test_model_hosp_no_timepoints_or_observations(): ), ) - I0 = DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)) + I0 = DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)) latent_infections = Infections() Rt_process = SimpleRt() @@ -100,7 +101,7 @@ def test_model_hosp_no_timepoints_or_observations(): latent_admissions = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, - infection_hospitalization_ratio_rv=DistributionalRV( + infection_hospitalization_ratio_rv=DistributionalVariable( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05) ), ) @@ -156,7 +157,7 @@ def test_model_hosp_both_timepoints_and_observations(): ), ) - I0 = DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)) + I0 = DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)) latent_infections = Infections() Rt_process = SimpleRt() @@ -164,7 +165,7 @@ def test_model_hosp_both_timepoints_and_observations(): latent_admissions = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, - infection_hospitalization_ratio_rv=DistributionalRV( + infection_hospitalization_ratio_rv=DistributionalVariable( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05) ), ) @@ -226,7 +227,7 @@ def test_model_hosp_no_obs_model(): I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=n_initialization_points), t_unit=1, ) @@ -236,7 +237,7 @@ def test_model_hosp_no_obs_model(): latent_admissions = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, - infection_hospitalization_ratio_rv=DistributionalRV( + infection_hospitalization_ratio_rv=DistributionalVariable( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05), ), @@ -338,7 +339,7 @@ def test_model_hosp_with_obs_model(): I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=n_initialization_points), t_unit=1, ) @@ -349,7 +350,7 @@ def test_model_hosp_with_obs_model(): latent_admissions = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, - infection_hospitalization_ratio_rv=DistributionalRV( + infection_hospitalization_ratio_rv=DistributionalVariable( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05), ), @@ -427,7 +428,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=n_initialization_points), t_unit=1, ) @@ -443,7 +444,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): infection_to_admission_interval_rv=inf_hosp, day_of_week_effect_rv=weekday, hospitalization_reporting_ratio_rv=hosp_report_prob_dist, - infection_hospitalization_ratio_rv=DistributionalRV( + infection_hospitalization_ratio_rv=DistributionalVariable( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05) ), ) @@ -518,11 +519,11 @@ def test_model_hosp_with_obs_model_weekday_phosp(): ), ) - n_initialization_points = max(gen_int.size(), inf_hosp.size()) - 1 + n_initialization_points = max(gen_int.size(), inf_hosp.size()) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=n_initialization_points), t_unit=1, ) @@ -534,6 +535,7 @@ def test_model_hosp_with_obs_model_weekday_phosp(): # Other random components total_length = n_obs_to_generate + pad_size + total_length = n_obs_to_generate + pad_size + 1 # gen_int.size() weekday = jnp.array([1, 1, 1, 1, 2, 2, 2]) weekday = weekday / weekday.sum() @@ -553,7 +555,7 @@ def test_model_hosp_with_obs_model_weekday_phosp(): infection_to_admission_interval_rv=inf_hosp, day_of_week_effect_rv=weekday, hospitalization_reporting_ratio_rv=hosp_report_prob_dist, - infection_hospitalization_ratio_rv=DistributionalRV( + infection_hospitalization_ratio_rv=DistributionalVariable( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05), ), diff --git a/test/test_predictive.py b/test/test_predictive.py index 5c76b98b..636578bb 100644 --- a/test/test_predictive.py +++ b/test/test_predictive.py @@ -17,15 +17,15 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation +from pyrenew.randomvariable import DistributionalVariable pmf_array = jnp.array([0.25, 0.1, 0.2, 0.45]) gen_int = DeterministicPMF(name="gen_int", value=pmf_array) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) diff --git a/test/test_random_key.py b/test/test_random_key.py index 6d6cfd43..0b99816f 100644 --- a/test/test_random_key.py +++ b/test/test_random_key.py @@ -19,9 +19,9 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation +from pyrenew.randomvariable import DistributionalVariable def create_test_model(): # numpydoc ignore=GL08 @@ -29,7 +29,7 @@ def create_test_model(): # numpydoc ignore=GL08 gen_int = DeterministicPMF(name="gen_int", value=pmf_array) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) diff --git a/test/test_random_walk.py b/test/test_random_walk.py index d7e2cabd..6997d679 100755 --- a/test/test_random_walk.py +++ b/test/test_random_walk.py @@ -7,15 +7,16 @@ from numpy.testing import assert_almost_equal, assert_array_almost_equal from pyrenew.deterministic import DeterministicVariable -from pyrenew.metaclass import DistributionalRV, RandomVariable +from pyrenew.metaclass import RandomVariable from pyrenew.process import RandomWalk, StandardNormalRandomWalk +from pyrenew.randomvariable import DistributionalVariable @pytest.mark.parametrize( ["element_rv", "init_value"], [ - [DistributionalRV("test_normal", dist.Normal(0.5, 1)), 50.0], - [DistributionalRV("test_cauchy", dist.Cauchy(0.25, 0.25)), -3], + [DistributionalVariable("test_normal", dist.Normal(0.5, 1)), 50.0], + [DistributionalVariable("test_cauchy", dist.Cauchy(0.25, 0.25)), -3], ["test standard normal", jnp.array(3)], ], ) @@ -81,7 +82,7 @@ def test_normal_rw_samples_correctly_distributed(step_mean, step_sd): rw_normal = StandardNormalRandomWalk("test standard normal") else: rw_normal = RandomWalk( - step_rv=DistributionalRV( + step_rv=DistributionalVariable( name="rw_step_dist", distribution=dist.Normal(loc=step_mean, scale=step_sd), ), diff --git a/test/test_transformed_rv_class.py b/test/test_transformed_rv_class.py index 353d59e0..22dd1c2c 100644 --- a/test/test_transformed_rv_class.py +++ b/test/test_transformed_rv_class.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ -Tests for TransformedRandomVariable class +Tests for TransformedVariable class """ from typing import NamedTuple @@ -13,13 +13,8 @@ from numpy.testing import assert_almost_equal import pyrenew.transformation as t -from pyrenew.metaclass import ( - DistributionalRV, - Model, - RandomVariable, - SampledValue, - TransformedRandomVariable, -) +from pyrenew.metaclass import Model, RandomVariable, SampledValue +from pyrenew.randomvariable import DistributionalVariable, TransformedVariable class LengthTwoRV(RandomVariable): @@ -129,11 +124,11 @@ def sample(self, **kwargs): # numpydoc ignore=GL08 def test_transform_rv_validation(): """ - Test that a TransformedRandomVariable validation + Test that a TransformedVariable validation works as expected. """ - base_rv = DistributionalRV( + base_rv = DistributionalVariable( name="test_normal", distribution=dist.Normal(0, 1) ) base_rv.sample_length = lambda: 1 # numpydoc ignore=GL08 @@ -143,41 +138,41 @@ def test_transform_rv_validation(): test_transforms = [t.IdentityTransform(), t.ExpTransform()] for tr in test_transforms: - my_rv = TransformedRandomVariable("test_transformed_rv", base_rv, tr) + my_rv = TransformedVariable("test_transformed_rv", base_rv, tr) assert isinstance(my_rv.transforms, tuple) assert len(my_rv.transforms) == 1 assert my_rv.sample_length() == 1 not_callable_err = "All entries in self.transforms " "must be callable" sample_length_err = "There must be exactly as many transformations" with pytest.raises(ValueError, match=sample_length_err): - _ = TransformedRandomVariable( + _ = TransformedVariable( "should_error_due_to_too_many_transforms", base_rv, (tr, tr) ) with pytest.raises(ValueError, match=sample_length_err): - _ = TransformedRandomVariable( + _ = TransformedVariable( "should_error_due_to_too_few_transforms", l2_rv, tr ) with pytest.raises(ValueError, match=sample_length_err): - _ = TransformedRandomVariable( + _ = TransformedVariable( "should_also_error_due_to_too_few_transforms", l2_rv, (tr,) ) with pytest.raises(ValueError, match=not_callable_err): - _ = TransformedRandomVariable( + _ = TransformedVariable( "should_error_due_to_not_callable", l2_rv, (1,) ) with pytest.raises(ValueError, match=not_callable_err): - _ = TransformedRandomVariable( + _ = TransformedVariable( "should_error_due_to_not_callable", base_rv, (1,) ) def test_transforms_applied_at_sampling(): """ - Test that TransformedRandomVariable + Test that TransformedVariable instances correctly apply their specified transformations at sampling """ - norm_rv = DistributionalRV( + norm_rv = DistributionalVariable( name="test_normal", distribution=dist.Normal(0, 1) ) norm_rv.sample_length = lambda: 1 @@ -190,9 +185,9 @@ def test_transforms_applied_at_sampling(): t.ExpTransform().inv, t.ScaledLogitTransform(5), ]: - tr_norm = TransformedRandomVariable("transformed_normal", norm_rv, tr) + tr_norm = TransformedVariable("transformed_normal", norm_rv, tr) - tr_l2 = TransformedRandomVariable( + tr_l2 = TransformedVariable( "transformed_length_2", l2_rv, (tr, t.ExpTransform()) ) @@ -217,22 +212,24 @@ def test_transforms_applied_at_sampling(): def test_transforms_variable_naming(): """ - Tests TransformedRandomVariable name + Tests TransformedVariable name recording is as expected. """ - transformed_dist_named_base_rv = TransformedRandomVariable( + transformed_dist_named_base_rv = TransformedVariable( "transformed_rv", NamedBaseRV(), (t.ExpTransform(), t.IdentityTransform()), ) - transformed_dist_unnamed_base_rv = TransformedRandomVariable( + transformed_dist_unnamed_base_rv = TransformedVariable( "transformed_rv", - DistributionalRV(name="my_normal", distribution=dist.Normal(0, 1)), + DistributionalVariable( + name="my_normal", distribution=dist.Normal(0, 1) + ), (t.ExpTransform(), t.IdentityTransform()), ) - transformed_dist_unnamed_base_l2_rv = TransformedRandomVariable( + transformed_dist_unnamed_base_l2_rv = TransformedVariable( "transformed_rv", LengthTwoRV(), (t.ExpTransform(), t.IdentityTransform()), diff --git a/test/utils.py b/test/utils.py index be551dfe..ac345b41 100644 --- a/test/utils.py +++ b/test/utils.py @@ -7,13 +7,9 @@ import numpyro.distributions as dist import pyrenew.transformation as t -from pyrenew.metaclass import ( - DistributionalRV, - RandomVariable, - SampledValue, - TransformedRandomVariable, -) +from pyrenew.metaclass import RandomVariable, SampledValue from pyrenew.process import RandomWalk +from pyrenew.randomvariable import DistributionalVariable, TransformedVariable class SimpleRt(RandomVariable): @@ -37,17 +33,17 @@ def __init__(self, name: str = "Rt_rv"): None """ self.name = name - self.rt_rv_ = TransformedRandomVariable( + self.rt_rv_ = TransformedVariable( name=f"{name}_log_rt_random_walk", base_rv=RandomWalk( name="log_rt", - step_rv=DistributionalRV( + step_rv=DistributionalVariable( name="rw_step_rv", distribution=dist.Normal(0, 0.025) ), ), transforms=t.ExpTransform(), ) - self.rt_init_rv_ = DistributionalRV( + self.rt_init_rv_ = DistributionalVariable( name=f"{name}_init_log_rt", distribution=dist.Normal(0, 0.2) )