Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make concentration in NegativeBinomialObservation a RandomVariable #267

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/tutorials/basic_renewal_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ rt_proc = RtRandomWalkProcess(
latent_infections = Infections()

# (5) The observed infections process (with mean at the latent infections)
observation_process = PoissonObservation()
observation_process = PoissonObservation("poisson_rv")
```

With these five pieces, we can build the basic renewal model as an instance of the `RtInfectionsRenewalModel` class:
Expand Down
10 changes: 9 additions & 1 deletion docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,15 @@ rtproc = process.RtRandomWalkProcess(
)

# The observation model
obs = observation.NegativeBinomialObservation(concentration_prior=1.0)
obs = observation.NegativeBinomialObservation(
"negbinom_rv",
metaclass.DistributionalRV(
dist.TransformedDistribution(
dist.HalfNormal(), transformation.PowerTransform(-2)
),
"concentration",
),
)
```

Notice all the components are `RandomVariable` instances. We can now build the model:
Expand Down
2 changes: 1 addition & 1 deletion docs/source/tutorials/pyrenew_demo.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ latent_admissions = HospitalAdmissions(
)

# 5) An observation process for the hospital admissions
admissions_process = PoissonObservation()
admissions_process = PoissonObservation("poisson_rv")

# 6) A random walk process (it could be deterministic using
# pyrenew.process.DeterministicProcess())
Expand Down
84 changes: 31 additions & 53 deletions model/src/pyrenew/observation/negativebinomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

from __future__ import annotations

import numbers as nums

import numpyro
import numpyro.distributions as dist
from jax.typing import ArrayLike
Expand All @@ -16,27 +14,24 @@ class NegativeBinomialObservation(RandomVariable):

def __init__(
self,
concentration_prior: dist.Distribution | ArrayLike,
concentration_suffix: str | None = "_concentration",
parameter_name="negbinom_rv",
name: str,
concentration_rv: RandomVariable,
eps: float = 1e-10,
) -> None:
"""
Default constructor

Parameters
----------
concentration_prior : dist.Distribution | numbers.nums
Numpyro distribution from which to sample the positive concentration
name : str
Name for the numpyro variable.
concentration : RandomVariable
Random variable from which to sample the positive concentration
parameter of the negative binomial. This parameter is sometimes
called k, phi, or the "dispersion" or "overdispersion" parameter,
despite the fact that larger values imply that the distribution
becomes more Poissonian, while smaller ones imply a greater degree
of dispersion.
concentration_suffix : str | None, optional
Suffix for the numpy variable. Defaults to "_concentration".
parameter_name : str, optional
Name for the numpy variable. Defaults to "negbinom_rv".
eps : float, optional
Small value to add to the predicted mean to prevent numerical
instability. Defaults to 1e-10.
Expand All @@ -46,25 +41,34 @@ def __init__(
None
"""

NegativeBinomialObservation.validate(concentration_prior)

if isinstance(concentration_prior, dist.Distribution):
self.sample_prior = lambda: numpyro.sample(
self.parameter_name + self.concentration_suffix,
concentration_prior,
)
else:
self.sample_prior = lambda: concentration_prior
NegativeBinomialObservation.validate(concentration_rv)

self.parameter_name = parameter_name
self.concentration_suffix = concentration_suffix
self.name = name
self.concentration_rv = concentration_rv
self.eps = eps

@staticmethod
def validate(concentration_rv: RandomVariable) -> None:
"""
Check that the concentration_rv is actually a RandomVariable

Parameters
----------
concentration_rv : any
RandomVariable from which to sample the positive concentration
parameter of the negative binomial.

Returns
-------
None
"""
assert isinstance(concentration_rv, RandomVariable)
return None

def sample(
self,
mu: ArrayLike,
obs: ArrayLike | None = None,
name: str | None = None,
**kwargs,
) -> tuple:
"""
Expand All @@ -76,49 +80,23 @@ def sample(
Mean parameter of the negative binomial distribution.
obs : ArrayLike, optional
Observed data, by default None.
name : str, optional
Name of the random variable if other than that defined during
construction, by default None (self.parameter_name).
**kwargs : dict, optional
Additional keyword arguments passed through to internal sample calls, should there be any.

Returns
-------
tuple
"""
concentration = self.sample_prior()
concentration, *_ = self.concentration_rv.sample()

if name is None:
name = self.parameter_name

return (
negative_binomial_sample = (
numpyro.sample(
name=name,
name=self.name,
fn=dist.NegativeBinomial2(
mean=mu + self.eps,
concentration=concentration,
),
obs=obs,
),
)

@staticmethod
def validate(concentration_prior: any) -> None:
"""
Check that the concentration prior is actually a nums.Number

Parameters
----------
concentration_prior : any
Numpyro distribution from which to sample the positive concentration
parameter of the negative binomial. Expected dist.Distribution or
numbers.nums

Returns
-------
None
"""
assert isinstance(
concentration_prior, (dist.Distribution, nums.Number)
)
return None
return (negative_binomial_sample,)
33 changes: 13 additions & 20 deletions model/src/pyrenew/observation/poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@

def __init__(
self,
parameter_name: str = "poisson_rv",
name: str,
eps: float = 1e-8,
) -> None:
"""
Default Constructor

Parameters
----------
parameter_name : str, optional
Passed to numpyro.sample. Defaults to "poisson_rv"
name : str, optional
Passed to numpyro.sample.
eps : float, optional
Small value added to the rate parameter to avoid zero values.
Defaults to 1e-8.
Expand All @@ -35,16 +35,19 @@
None
"""

self.parameter_name = parameter_name
self.name = name
self.eps = eps

return None

@staticmethod
def validate(): # numpydoc ignore=GL08
None

Check warning on line 45 in model/src/pyrenew/observation/poisson.py

View check run for this annotation

Codecov / codecov/patch

model/src/pyrenew/observation/poisson.py#L45

Added line #L45 was not covered by tests

def sample(
self,
mu: ArrayLike,
obs: ArrayLike | None = None,
name: str | None = None,
**kwargs,
) -> tuple:
"""
Expand All @@ -56,8 +59,6 @@
Rate parameter of the Poisson distribution.
obs : ArrayLike | None, optional
Observed data. Defaults to None.
name : str | None, optional
Name of the random variable. Defaults to None.
**kwargs : dict, optional
Additional keyword arguments passed through to internal sample calls, should there be any.

Expand All @@ -66,17 +67,9 @@
tuple
"""

if name is None:
name = self.parameter_name

return (
numpyro.sample(
name=name,
fn=dist.Poisson(rate=mu + self.eps),
obs=obs,
),
poisson_sample = numpyro.sample(
name=self.name,
fn=dist.Poisson(rate=mu + self.eps),
obs=obs,
)

@staticmethod
def validate(): # numpydoc ignore=GL08
None
return (poisson_sample,)
2 changes: 1 addition & 1 deletion model/src/test/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_forecast():
t_unit=1,
)
latent_infections = Infections()
observed_infections = PoissonObservation()
observed_infections = PoissonObservation("poisson_rv")
rt = RtRandomWalkProcess(
Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0),
Rt_transform=t.ExpTransform().inv,
Expand Down
8 changes: 4 additions & 4 deletions model/src/test/test_model_basic_renewal.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_model_basicrenewal_no_timepoints_or_observations():

latent_infections = Infections()

observed_infections = PoissonObservation()
observed_infections = PoissonObservation("poisson_rv")

rt = RtRandomWalkProcess(
Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0),
Expand Down Expand Up @@ -72,7 +72,7 @@ def test_model_basicrenewal_both_timepoints_and_observations():

latent_infections = Infections()

observed_infections = PoissonObservation()
observed_infections = PoissonObservation("possion_rv")

rt = RtRandomWalkProcess(
Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0),
Expand Down Expand Up @@ -195,7 +195,7 @@ def test_model_basicrenewal_with_obs_model():

latent_infections = Infections()

observed_infections = PoissonObservation()
observed_infections = PoissonObservation("poisson_rv")

rt = RtRandomWalkProcess(
Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0),
Expand Down Expand Up @@ -249,7 +249,7 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08

latent_infections = Infections()

observed_infections = PoissonObservation()
observed_infections = PoissonObservation("poisson_rv")

rt = RtRandomWalkProcess(
Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0),
Expand Down
10 changes: 5 additions & 5 deletions model/src/test/test_model_hospitalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_model_hosp_no_timepoints_or_observations():
Rt_transform=t.ExpTransform().inv,
Rt_rw_dist=dist.Normal(0, 0.025),
)
observed_admissions = PoissonObservation()
observed_admissions = PoissonObservation("poisson_rv")

inf_hosp = DeterministicPMF(
jnp.array(
Expand Down Expand Up @@ -129,7 +129,7 @@ def test_model_hosp_both_timepoints_and_observations():
Rt_transform=t.ExpTransform().inv,
Rt_rw_dist=dist.Normal(0, 0.025),
)
observed_admissions = PoissonObservation()
observed_admissions = PoissonObservation("poisson_rv")

inf_hosp = DeterministicPMF(
jnp.array(
Expand Down Expand Up @@ -315,7 +315,7 @@ def test_model_hosp_with_obs_model():
Rt_transform=t.ExpTransform().inv,
Rt_rw_dist=dist.Normal(0, 0.025),
)
observed_admissions = PoissonObservation()
observed_admissions = PoissonObservation("poisson_rv")

inf_hosp = DeterministicPMF(
jnp.array(
Expand Down Expand Up @@ -405,7 +405,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2():
Rt_transform=t.ExpTransform().inv,
Rt_rw_dist=dist.Normal(0, 0.025),
)
observed_admissions = PoissonObservation()
observed_admissions = PoissonObservation("poisson_rv")

inf_hosp = DeterministicPMF(
jnp.array(
Expand Down Expand Up @@ -508,7 +508,7 @@ def test_model_hosp_with_obs_model_weekday_phosp():
Rt_transform=t.ExpTransform().inv,
Rt_rw_dist=dist.Normal(0, 0.025),
)
observed_admissions = PoissonObservation()
observed_admissions = PoissonObservation("poisson_rv")

inf_hosp = DeterministicPMF(
jnp.array(
Expand Down
11 changes: 9 additions & 2 deletions model/src/test/test_observation_negativebinom.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import numpy.testing as testing
import numpyro as npro
from pyrenew.deterministic import DeterministicVariable
from pyrenew.observation import NegativeBinomialObservation


Expand All @@ -12,7 +13,10 @@ def test_negativebinom_deterministic_obs():
Check that a deterministic NegativeBinomialObservation can sample
"""

negb = NegativeBinomialObservation(concentration_prior=10)
negb = NegativeBinomialObservation(
"negbinom_rv",
concentration_rv=DeterministicVariable(10, name="concentration"),
)

np.random.seed(223)
rates = np.random.randint(1, 5, size=10)
Expand All @@ -31,7 +35,10 @@ def test_negativebinom_random_obs():
Check that a random NegativeBinomialObservation can sample
"""

negb = NegativeBinomialObservation(concentration_prior=10)
negb = NegativeBinomialObservation(
"negbinom_rv",
concentration_rv=DeterministicVariable(10, "concentration"),
)

np.random.seed(223)
rates = np.repeat(5, 20000)
Expand Down
2 changes: 1 addition & 1 deletion model/src/test/test_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
t_unit=1,
)
latent_infections = Infections()
observed_infections = PoissonObservation()
observed_infections = PoissonObservation("poisson_rv")
rt = RtRandomWalkProcess(
Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0),
Rt_transform=t.ExpTransform().inv,
Expand Down
2 changes: 1 addition & 1 deletion model/src/test/test_random_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def create_test_model(): # numpydoc ignore=GL08
t_unit=1,
)
latent_infections = Infections()
observed_infections = PoissonObservation()
observed_infections = PoissonObservation("poisson_rv")
rt = RtRandomWalkProcess(
Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0),
Rt_transform=t.ExpTransform().inv,
Expand Down