diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index b5ce439d..2e324a08 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -132,7 +132,7 @@ rt_proc = RtRandomWalkProcess( latent_infections = Infections() # (5) The observed infections process (with mean at the latent infections) -observation_process = PoissonObservation() +observation_process = PoissonObservation("poisson_rv") ``` With these five pieces, we can build the basic renewal model as an instance of the `RtInfectionsRenewalModel` class: diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index 5f2c9a98..f550ea98 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -176,7 +176,15 @@ rtproc = process.RtRandomWalkProcess( ) # The observation model -obs = observation.NegativeBinomialObservation(concentration_prior=1.0) +obs = observation.NegativeBinomialObservation( + "negbinom_rv", + metaclass.DistributionalRV( + dist.TransformedDistribution( + dist.HalfNormal(), transformation.PowerTransform(-2) + ), + "concentration", + ), +) ``` Notice all the components are `RandomVariable` instances. We can now build the model: diff --git a/docs/source/tutorials/pyrenew_demo.qmd b/docs/source/tutorials/pyrenew_demo.qmd index 73daa493..858b41a6 100644 --- a/docs/source/tutorials/pyrenew_demo.qmd +++ b/docs/source/tutorials/pyrenew_demo.qmd @@ -124,7 +124,7 @@ latent_admissions = HospitalAdmissions( ) # 5) An observation process for the hospital admissions -admissions_process = PoissonObservation() +admissions_process = PoissonObservation("poisson_rv") # 6) A random walk process (it could be deterministic using # pyrenew.process.DeterministicProcess()) diff --git a/model/src/pyrenew/observation/negativebinomial.py b/model/src/pyrenew/observation/negativebinomial.py index 48710592..3eb21d0f 100644 --- a/model/src/pyrenew/observation/negativebinomial.py +++ b/model/src/pyrenew/observation/negativebinomial.py @@ -3,8 +3,6 @@ from __future__ import annotations -import numbers as nums - import numpyro import numpyro.distributions as dist from jax.typing import ArrayLike @@ -16,9 +14,8 @@ class NegativeBinomialObservation(RandomVariable): def __init__( self, - concentration_prior: dist.Distribution | ArrayLike, - concentration_suffix: str | None = "_concentration", - parameter_name="negbinom_rv", + name: str, + concentration_rv: RandomVariable, eps: float = 1e-10, ) -> None: """ @@ -26,17 +23,15 @@ def __init__( Parameters ---------- - concentration_prior : dist.Distribution | numbers.nums - Numpyro distribution from which to sample the positive concentration + name : str + Name for the numpyro variable. + concentration : RandomVariable + Random variable from which to sample the positive concentration parameter of the negative binomial. This parameter is sometimes called k, phi, or the "dispersion" or "overdispersion" parameter, despite the fact that larger values imply that the distribution becomes more Poissonian, while smaller ones imply a greater degree of dispersion. - concentration_suffix : str | None, optional - Suffix for the numpy variable. Defaults to "_concentration". - parameter_name : str, optional - Name for the numpy variable. Defaults to "negbinom_rv". eps : float, optional Small value to add to the predicted mean to prevent numerical instability. Defaults to 1e-10. @@ -46,25 +41,34 @@ def __init__( None """ - NegativeBinomialObservation.validate(concentration_prior) - - if isinstance(concentration_prior, dist.Distribution): - self.sample_prior = lambda: numpyro.sample( - self.parameter_name + self.concentration_suffix, - concentration_prior, - ) - else: - self.sample_prior = lambda: concentration_prior + NegativeBinomialObservation.validate(concentration_rv) - self.parameter_name = parameter_name - self.concentration_suffix = concentration_suffix + self.name = name + self.concentration_rv = concentration_rv self.eps = eps + @staticmethod + def validate(concentration_rv: RandomVariable) -> None: + """ + Check that the concentration_rv is actually a RandomVariable + + Parameters + ---------- + concentration_rv : any + RandomVariable from which to sample the positive concentration + parameter of the negative binomial. + + Returns + ------- + None + """ + assert isinstance(concentration_rv, RandomVariable) + return None + def sample( self, mu: ArrayLike, obs: ArrayLike | None = None, - name: str | None = None, **kwargs, ) -> tuple: """ @@ -76,9 +80,6 @@ def sample( Mean parameter of the negative binomial distribution. obs : ArrayLike, optional Observed data, by default None. - name : str, optional - Name of the random variable if other than that defined during - construction, by default None (self.parameter_name). **kwargs : dict, optional Additional keyword arguments passed through to internal sample calls, should there be any. @@ -86,14 +87,11 @@ def sample( ------- tuple """ - concentration = self.sample_prior() + concentration, *_ = self.concentration_rv.sample() - if name is None: - name = self.parameter_name - - return ( + negative_binomial_sample = ( numpyro.sample( - name=name, + name=self.name, fn=dist.NegativeBinomial2( mean=mu + self.eps, concentration=concentration, @@ -101,24 +99,4 @@ def sample( obs=obs, ), ) - - @staticmethod - def validate(concentration_prior: any) -> None: - """ - Check that the concentration prior is actually a nums.Number - - Parameters - ---------- - concentration_prior : any - Numpyro distribution from which to sample the positive concentration - parameter of the negative binomial. Expected dist.Distribution or - numbers.nums - - Returns - ------- - None - """ - assert isinstance( - concentration_prior, (dist.Distribution, nums.Number) - ) - return None + return (negative_binomial_sample,) diff --git a/model/src/pyrenew/observation/poisson.py b/model/src/pyrenew/observation/poisson.py index c641cf76..4f378197 100644 --- a/model/src/pyrenew/observation/poisson.py +++ b/model/src/pyrenew/observation/poisson.py @@ -16,7 +16,7 @@ class PoissonObservation(RandomVariable): def __init__( self, - parameter_name: str = "poisson_rv", + name: str, eps: float = 1e-8, ) -> None: """ @@ -24,8 +24,8 @@ def __init__( Parameters ---------- - parameter_name : str, optional - Passed to numpyro.sample. Defaults to "poisson_rv" + name : str, optional + Passed to numpyro.sample. eps : float, optional Small value added to the rate parameter to avoid zero values. Defaults to 1e-8. @@ -35,16 +35,19 @@ def __init__( None """ - self.parameter_name = parameter_name + self.name = name self.eps = eps return None + @staticmethod + def validate(): # numpydoc ignore=GL08 + None + def sample( self, mu: ArrayLike, obs: ArrayLike | None = None, - name: str | None = None, **kwargs, ) -> tuple: """ @@ -56,8 +59,6 @@ def sample( Rate parameter of the Poisson distribution. obs : ArrayLike | None, optional Observed data. Defaults to None. - name : str | None, optional - Name of the random variable. Defaults to None. **kwargs : dict, optional Additional keyword arguments passed through to internal sample calls, should there be any. @@ -66,17 +67,9 @@ def sample( tuple """ - if name is None: - name = self.parameter_name - - return ( - numpyro.sample( - name=name, - fn=dist.Poisson(rate=mu + self.eps), - obs=obs, - ), + poisson_sample = numpyro.sample( + name=self.name, + fn=dist.Poisson(rate=mu + self.eps), + obs=obs, ) - - @staticmethod - def validate(): # numpydoc ignore=GL08 - None + return (poisson_sample,) diff --git a/model/src/test/test_forecast.py b/model/src/test/test_forecast.py index 90d4b2fd..523297b3 100644 --- a/model/src/test/test_forecast.py +++ b/model/src/test/test_forecast.py @@ -30,7 +30,7 @@ def test_forecast(): t_unit=1, ) latent_infections = Infections() - observed_infections = PoissonObservation() + observed_infections = PoissonObservation("poisson_rv") rt = RtRandomWalkProcess( Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), Rt_transform=t.ExpTransform().inv, diff --git a/model/src/test/test_model_basic_renewal.py b/model/src/test/test_model_basic_renewal.py index 22c5d992..d0961eb3 100644 --- a/model/src/test/test_model_basic_renewal.py +++ b/model/src/test/test_model_basic_renewal.py @@ -35,7 +35,7 @@ def test_model_basicrenewal_no_timepoints_or_observations(): latent_infections = Infections() - observed_infections = PoissonObservation() + observed_infections = PoissonObservation("poisson_rv") rt = RtRandomWalkProcess( Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), @@ -72,7 +72,7 @@ def test_model_basicrenewal_both_timepoints_and_observations(): latent_infections = Infections() - observed_infections = PoissonObservation() + observed_infections = PoissonObservation("possion_rv") rt = RtRandomWalkProcess( Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), @@ -195,7 +195,7 @@ def test_model_basicrenewal_with_obs_model(): latent_infections = Infections() - observed_infections = PoissonObservation() + observed_infections = PoissonObservation("poisson_rv") rt = RtRandomWalkProcess( Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), @@ -249,7 +249,7 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08 latent_infections = Infections() - observed_infections = PoissonObservation() + observed_infections = PoissonObservation("poisson_rv") rt = RtRandomWalkProcess( Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), diff --git a/model/src/test/test_model_hospitalizations.py b/model/src/test/test_model_hospitalizations.py index 056b1798..7eb6a8c4 100644 --- a/model/src/test/test_model_hospitalizations.py +++ b/model/src/test/test_model_hospitalizations.py @@ -60,7 +60,7 @@ def test_model_hosp_no_timepoints_or_observations(): Rt_transform=t.ExpTransform().inv, Rt_rw_dist=dist.Normal(0, 0.025), ) - observed_admissions = PoissonObservation() + observed_admissions = PoissonObservation("poisson_rv") inf_hosp = DeterministicPMF( jnp.array( @@ -129,7 +129,7 @@ def test_model_hosp_both_timepoints_and_observations(): Rt_transform=t.ExpTransform().inv, Rt_rw_dist=dist.Normal(0, 0.025), ) - observed_admissions = PoissonObservation() + observed_admissions = PoissonObservation("poisson_rv") inf_hosp = DeterministicPMF( jnp.array( @@ -315,7 +315,7 @@ def test_model_hosp_with_obs_model(): Rt_transform=t.ExpTransform().inv, Rt_rw_dist=dist.Normal(0, 0.025), ) - observed_admissions = PoissonObservation() + observed_admissions = PoissonObservation("poisson_rv") inf_hosp = DeterministicPMF( jnp.array( @@ -405,7 +405,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): Rt_transform=t.ExpTransform().inv, Rt_rw_dist=dist.Normal(0, 0.025), ) - observed_admissions = PoissonObservation() + observed_admissions = PoissonObservation("poisson_rv") inf_hosp = DeterministicPMF( jnp.array( @@ -508,7 +508,7 @@ def test_model_hosp_with_obs_model_weekday_phosp(): Rt_transform=t.ExpTransform().inv, Rt_rw_dist=dist.Normal(0, 0.025), ) - observed_admissions = PoissonObservation() + observed_admissions = PoissonObservation("poisson_rv") inf_hosp = DeterministicPMF( jnp.array( diff --git a/model/src/test/test_observation_negativebinom.py b/model/src/test/test_observation_negativebinom.py index 2f389378..3ff66ad3 100644 --- a/model/src/test/test_observation_negativebinom.py +++ b/model/src/test/test_observation_negativebinom.py @@ -4,6 +4,7 @@ import numpy as np import numpy.testing as testing import numpyro as npro +from pyrenew.deterministic import DeterministicVariable from pyrenew.observation import NegativeBinomialObservation @@ -12,7 +13,10 @@ def test_negativebinom_deterministic_obs(): Check that a deterministic NegativeBinomialObservation can sample """ - negb = NegativeBinomialObservation(concentration_prior=10) + negb = NegativeBinomialObservation( + "negbinom_rv", + concentration_rv=DeterministicVariable(10, name="concentration"), + ) np.random.seed(223) rates = np.random.randint(1, 5, size=10) @@ -31,7 +35,10 @@ def test_negativebinom_random_obs(): Check that a random NegativeBinomialObservation can sample """ - negb = NegativeBinomialObservation(concentration_prior=10) + negb = NegativeBinomialObservation( + "negbinom_rv", + concentration_rv=DeterministicVariable(10, "concentration"), + ) np.random.seed(223) rates = np.repeat(5, 20000) diff --git a/model/src/test/test_predictive.py b/model/src/test/test_predictive.py index 693cf638..1089974e 100644 --- a/model/src/test/test_predictive.py +++ b/model/src/test/test_predictive.py @@ -28,7 +28,7 @@ t_unit=1, ) latent_infections = Infections() -observed_infections = PoissonObservation() +observed_infections = PoissonObservation("poisson_rv") rt = RtRandomWalkProcess( Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), Rt_transform=t.ExpTransform().inv, diff --git a/model/src/test/test_random_key.py b/model/src/test/test_random_key.py index c181dfbd..5f1c9986 100644 --- a/model/src/test/test_random_key.py +++ b/model/src/test/test_random_key.py @@ -34,7 +34,7 @@ def create_test_model(): # numpydoc ignore=GL08 t_unit=1, ) latent_infections = Infections() - observed_infections = PoissonObservation() + observed_infections = PoissonObservation("poisson_rv") rt = RtRandomWalkProcess( Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), Rt_transform=t.ExpTransform().inv,