Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dynamic and static distributional rvs #391

Merged
merged 4 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/source/tutorials/basic_renewal_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ gen_int = DeterministicPMF(name="gen_int", value=pmf_array)
# (2) Initial infections (inferred with a prior)
I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalRV(name="I0", dist=dist.LogNormal(2.5, 1)),
DistributionalRV(name="I0", distribution=dist.LogNormal(2.5, 1)),
InitializeInfectionsZeroPad(pmf_array.size),
t_unit=1,
)
Expand All @@ -148,12 +148,12 @@ class MyRt(RandomVariable):
name="log_rt",
step_rv=DistributionalRV(
name="rw_step_rv",
dist=dist.Normal(0, sd_rt),
distribution=dist.Normal(0, sd_rt),
reparam=LocScaleReparam(0),
),
init_rv=DistributionalRV(
name="init_log_rt",
dist=dist.Normal(jnp.log(1), jnp.log(1.2)),
distribution=dist.Normal(jnp.log(1), jnp.log(1.2)),
),
),
transforms=t.ExpTransform(),
Expand Down
8 changes: 5 additions & 3 deletions docs/source/tutorials/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ feedback_strength = DeterministicVariable(name="feedback_strength", value=0.01)

I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)),
DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)),
InitializeInfectionsExponentialGrowth(
gen_int_array.size,
DeterministicVariable(name="rate", value=0.05),
Expand All @@ -67,9 +67,11 @@ rt = TransformedRandomVariable(
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, 0.025)
name="rw_step_rv", distribution=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_rt", distribution=dist.Normal(0, 0.2)
),
init_rv=DistributionalRV(name="init_log_rt", dist=dist.Normal(0, 0.2)),
),
transforms=t.ExpTransform(),
)
Expand Down
13 changes: 7 additions & 6 deletions docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ inf_hosp_int = deterministic.DeterministicPMF(
)

hosp_rate = metaclass.DistributionalRV(
name="IHR", dist=dist.LogNormal(jnp.log(0.05), jnp.log(1.1))
name="IHR", distribution=dist.LogNormal(jnp.log(0.05), jnp.log(1.1))
)

latent_hosp = latent.HospitalAdmissions(
Expand All @@ -171,7 +171,8 @@ latent_inf = latent.Infections()
I0 = InfectionInitializationProcess(
"I0_initialization",
metaclass.DistributionalRV(
name="I0", dist=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75))
name="I0",
distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)),
),
InitializeInfectionsExponentialGrowth(
gen_int_array.size,
Expand Down Expand Up @@ -199,10 +200,10 @@ class MyRt(metaclass.RandomVariable):
base_rv=process.SimpleRandomWalkProcess(
name="log_rt",
step_rv=metaclass.DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, sd_rt.value)
name="rw_step_rv", distribution=dist.Normal(0, sd_rt.value)
),
init_rv=metaclass.DistributionalRV(
name="init_log_rt", dist=dist.Normal(0, 0.2)
name="init_log_rt", distribution=dist.Normal(0, 0.2)
),
),
transforms=transformation.ExpTransform(),
Expand All @@ -213,7 +214,7 @@ class MyRt(metaclass.RandomVariable):

rtproc = MyRt(
metaclass.DistributionalRV(
name="Rt_random_walk_sd", dist=dist.HalfNormal(0.025)
name="Rt_random_walk_sd", distribution=dist.HalfNormal(0.025)
)
)

Expand All @@ -225,7 +226,7 @@ nb_conc_rv = metaclass.TransformedRandomVariable(
"concentration",
metaclass.DistributionalRV(
name="concentration_raw",
dist=dist.TruncatedNormal(loc=0, scale=1, low=0.01),
distribution=dist.TruncatedNormal(loc=0, scale=1, low=0.01),
),
transformation.PowerTransform(-2),
)
Expand Down
2 changes: 1 addition & 1 deletion docs/source/tutorials/periodic_effects.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ mysimplex = dist.TransformedDistribution(
dayofweek = process.DayOfWeekEffect(
offset=0,
quantity_to_broadcast=metaclass.DistributionalRV(
name="simp", dist=mysimplex
name="simp", distribution=mysimplex
),
t_start=0,
)
Expand Down
198 changes: 182 additions & 16 deletions model/src/pyrenew/metaclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
"""

from abc import ABCMeta, abstractmethod
from typing import NamedTuple, get_type_hints
from typing import Callable, NamedTuple, get_type_hints

import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist
import polars as pl
from jax.typing import ArrayLike
from numpyro.infer import MCMC, NUTS, Predictive
Expand Down Expand Up @@ -126,7 +127,7 @@

class SampledValue(NamedTuple):
"""
A container for a sampled value from a RandomVariable.
A container for a value sampled from a RandomVariable.

Attributes
----------
Expand All @@ -135,7 +136,8 @@
t_start : int, optional
The start time of the value.
t_unit : int, optional
The unit of time relative to the model's fundamental (smallest) time unit.
The unit of time relative to the model's fundamental
(smallest) time unit.
damonbayer marked this conversation as resolved.
Show resolved Hide resolved
"""

value: ArrayLike | None = None
Expand Down Expand Up @@ -274,16 +276,127 @@
return self.sample(**kwargs)


class DistributionalRV(RandomVariable):
class DynamicDistributionalRV(RandomVariable):
"""
Wrapper class for random variables that sample
from a single :class:`numpyro.distributions.Distribution`.
from a single :class:`numpyro.distributions.Distribution`
that is parameterized / instantiated at `sample()` time
(rather than at RandomVariable instantiation time).
"""

def __init__(
self,
name: str,
dist: numpyro.distributions.Distribution,
distribution_constructor: Callable,
reparam: Reparam = None,
) -> None:
"""
Default constructor for DynamicDistributionalRV.

Parameters
----------
name : str
Name of the random variable.
distribution_constructor : Callable
Callable that returns a concrete parametrized
numpyro.Distributions.distribution instance.
reparam : numpyro.infer.reparam.Reparam
If not None, reparameterize sampling
from the distribution according to the
given numpyro reparameterizer

Returns
-------
None
"""

self.name = name
self.validate(distribution_constructor)
self.distribution_constructor = distribution_constructor
if reparam is not None:
self.reparam_dict = {self.name: reparam}

Check warning on line 317 in model/src/pyrenew/metaclass.py

View check run for this annotation

Codecov / codecov/patch

model/src/pyrenew/metaclass.py#L317

Added line #L317 was not covered by tests
else:
self.reparam_dict = {}

return None

@staticmethod
def validate(distribution_constructor: any) -> None:
"""
Confirm that the distribution_constructor is
callable.

Parameters
----------
distribution_constructor : any
Putative distribution_constructor to validate.

Returns
-------
None or raises a ValueError
"""
if not callable(distribution_constructor):
raise ValueError(
"To instantiate a DynamicDistributionalRV, ",
"one must provide a Callable that returns a "
"numpyro.distributions.Distribution as the "
"distribution_constructor argument. "
f"Got {type(distribution_constructor)}, which "
"does not appear to be callable",
)
return None

def sample(
self,
*args,
obs: ArrayLike = None,
**kwargs,
) -> tuple:
"""
Sample from the distributional rv.

Parameters
----------
*args :
Positional arguments passed to self.distribution_constructor
obs : ArrayLike, optional
Observations passed as the `obs` argument to
:fun:`numpyro.sample()`. Default `None`.
**kwargs : dict, optional
Keyword arguments passed to self.distribution_constructor

Returns
-------
SampledValue
Containing a sample from the distribution.
"""
with numpyro.handlers.reparam(config=self.reparam_dict):
sample = numpyro.sample(
name=self.name,
fn=self.distribution_constructor(*args, **kwargs),
obs=obs,
)
return (
SampledValue(
jnp.atleast_1d(sample),
t_start=self.t_start,
t_unit=self.t_unit,
),
)


class StaticDistributionalRV(RandomVariable):
"""
Wrapper class for random variables that sample
from a single :class:`numpyro.distributions.Distribution`
that is parameterized / instantiated at RandomVariable
instantiation time (rather than at `sample()`-ing time).
"""

def __init__(
self,
name: str,
distribution: numpyro.distributions.Distribution,
reparam: Reparam = None,
) -> None:
"""
Expand All @@ -293,7 +406,7 @@
----------
name : str
Name of the random variable.
dist : numpyro.distributions.Distribution
distribution : numpyro.distributions.Distribution
Distribution of the random variable.
reparam : numpyro.infer.reparam.Reparam
If not None, reparameterize sampling
Expand All @@ -306,8 +419,8 @@
"""

self.name = name
self.validate(dist)
self.dist = dist
self.validate(distribution)
self.distribution = distribution
if reparam is not None:
self.reparam_dict = {self.name: reparam}
else:
Expand All @@ -316,14 +429,15 @@
return None

@staticmethod
def validate(dist: any) -> None:
def validate(distribution: any) -> None:
"""
Validation of the distribution to be implemented in subclasses.
"""
if not isinstance(dist, numpyro.distributions.Distribution):
if not isinstance(distribution, numpyro.distributions.Distribution):
raise ValueError(
"dist should be an instance of "
f"numpyro.distributions.Distribution, got {dist}"
"distribution should be an instance of "
"numpyro.distributions.Distribution, got "
"{type(distribution)}"
)

return None
Expand All @@ -347,13 +461,13 @@

Returns
-------
tuple
Containing the sampled from the distribution.
SampledValue
Containing a sample from the distribution.
"""
with numpyro.handlers.reparam(config=self.reparam_dict):
sample = numpyro.sample(
name=self.name,
fn=self.dist,
fn=self.distribution,
obs=obs,
)
return (
Expand All @@ -365,6 +479,58 @@
)


def DistributionalRV(
name: str,
distribution: numpyro.distributions.Distribution | Callable,
reparam: Reparam = None,
) -> RandomVariable:
"""
Factory function to generate Distributional RandomVariables,
either static or dynamic.

Parameters
----------
name : str
Name of the random variable.

distribution: numpyro.distributions.Distribution | Callable
Either numpyro.distributions.Distribution instance
given the static distribution of the random variable or
a callable that returns a parameterized
numpyro.distributions.Distribution when called, which
allows for dynamically-parameterized DistributionalRVs,
e.g. a Normal distribution with an inferred location and
scale.

reparam : numpyro.infer.reparam.Reparam
If not None, reparameterize sampling
from the distribution according to the
given numpyro reparameterizer

Returns
-------
DynamicDistributionalRV | StaticDistributionalRV or
raises a ValueError if a distribution cannot be constructed.
"""
if isinstance(distribution, dist.Distribution):
return StaticDistributionalRV(
name=name, distribution=distribution, reparam=reparam
)
elif callable(distribution):
return DynamicDistributionalRV(
name=name, distribution_constructor=distribution, reparam=reparam
)
else:
raise ValueError(
"distribution argument to DistributionalRV "
"must be either a numpyro.distributions.Distribution "
"(for instantiating a static DistributionalRV) "
"or a callable that returns a "
"numpyro.distributions.Distribution (for "
"a dynamic DistributionalRV"
)


class Model(metaclass=ABCMeta):
"""Abstract base class for models"""

Expand Down
2 changes: 1 addition & 1 deletion model/src/test/test_assert_sample_and_rtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_input_rv(): # numpydoc ignore=GL08
valid_rv = [
NullObservation(),
DeterministicVariable(name="rv1", value=jnp.array([1, 2, 3, 4])),
DistributionalRV(name="rv2", dist=dist.Normal(0, 1)),
DistributionalRV(name="rv2", distribution=dist.Normal(0, 1)),
]
not_rv = jnp.array([1])

Expand Down
Loading