Skip to content

Commit

Permalink
Pseudo-Random Number Generation For None Input For JAX PRNGKey In Met…
Browse files Browse the repository at this point in the history
…aclass (#192)

* one way of psuedo random key

* change type to ArrayLike

* add numpy

* add None case in tests and add jax random key instead of PRNGKey

* test for rng_key during run

* rename misspelled file

* fix tests, adapt names, fix PRNGKey

* remove added file

* Update model/src/pyrenew/metaclass.py

Clever and nice.

Co-authored-by: Damon Bayer <[email protected]>

* DRY a bit less for tests

---------

Co-authored-by: Damon Bayer <[email protected]>
  • Loading branch information
AFg6K7h4fhy2 and damonbayer authored Jun 17, 2024
1 parent 657a206 commit 407bc4e
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 11 deletions.
9 changes: 8 additions & 1 deletion model/src/pyrenew/metaclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import numpy as np
import numpyro as npro
import polars as pl
from jax.typing import ArrayLike
Expand Down Expand Up @@ -312,7 +314,7 @@ def run(
self,
num_warmup,
num_samples,
rng_key: jax.random.PRNGKey = jax.random.PRNGKey(54),
rng_key: ArrayLike | None = None,
nuts_args: dict = None,
mcmc_args: dict = None,
**kwargs,
Expand All @@ -339,6 +341,11 @@ def run(
nuts_args=nuts_args,
mcmc_args=mcmc_args,
)
if rng_key is None:
rand_int = np.random.randint(
np.iinfo(np.int64).min, np.iinfo(np.int64).max
)
rng_key = jr.key(rand_int)

self.mcmc.run(rng_key=rng_key, **kwargs)

Expand Down
11 changes: 6 additions & 5 deletions model/src/test/test_model_basic_renewal.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-
# numpydoc ignore=GL08

import jax

import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import numpy as np
import numpyro as npro
Expand Down Expand Up @@ -159,7 +160,7 @@ def test_model_basicrenewal_no_obs_model():
model0.run(
num_warmup=500,
num_samples=500,
rng_key=jax.random.PRNGKey(272),
rng_key=jr.key(272),
data_observed_infections=model0_samp.latent_infections,
)

Expand Down Expand Up @@ -217,7 +218,7 @@ def test_model_basicrenewal_with_obs_model():
model1.run(
num_warmup=500,
num_samples=500,
rng_key=jax.random.PRNGKey(22),
rng_key=jr.key(22),
data_observed_infections=model1_samp.observed_infections,
)

Expand Down Expand Up @@ -290,7 +291,7 @@ def test_model_basicrenewal_plot() -> plt.Figure:
model1.run(
num_warmup=500,
num_samples=500,
rng_key=jax.random.PRNGKey(22),
rng_key=jr.key(22),
data_observed_infections=model1_samp.observed_infections,
)

Expand Down Expand Up @@ -341,7 +342,7 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08
model1.run(
num_warmup=500,
num_samples=500,
rng_key=jax.random.PRNGKey(22),
rng_key=jr.key(22),
data_observed_infections=new_obs,
padding=5,
)
Expand Down
11 changes: 6 additions & 5 deletions model/src/test/test_model_hospitalizations.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-
# numpydoc ignore=GL08

import jax

import jax.numpy as jnp
import jax.random as jr
import numpy as np
import numpyro as npro
import numpyro.distributions as dist
Expand Down Expand Up @@ -275,7 +276,7 @@ def test_model_hosp_no_obs_model():
model0.run(
num_warmup=500,
num_samples=500,
rng_key=jax.random.PRNGKey(272),
rng_key=jr.key(272),
data_observed_hosp_admissions=model0_samp.latent_hosp_admissions,
)

Expand Down Expand Up @@ -364,7 +365,7 @@ def test_model_hosp_with_obs_model():
model1.run(
num_warmup=500,
num_samples=500,
rng_key=jax.random.PRNGKey(272),
rng_key=jr.key(272),
data_observed_hosp_admissions=model1_samp.observed_hosp_admissions,
)

Expand Down Expand Up @@ -464,7 +465,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2():
model1.run(
num_warmup=500,
num_samples=500,
rng_key=jax.random.PRNGKey(272),
rng_key=jr.key(272),
data_observed_hosp_admissions=model1_samp.observed_hosp_admissions,
)

Expand Down Expand Up @@ -581,7 +582,7 @@ def test_model_hosp_with_obs_model_weekday_phosp():
model1.run(
num_warmup=500,
num_samples=500,
rng_key=jax.random.PRNGKey(272),
rng_key=jr.key(272),
data_observed_hosp_admissions=obs,
padding=5,
)
Expand Down
113 changes: 113 additions & 0 deletions model/src/test/test_random_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-

"""
Ensures that models created with the same or
with different random keys behave appropriately.
"""

import jax.numpy as jnp
import jax.random as jr
import numpy as np
import numpyro as npro
import numpyro.distributions as dist
import pyrenew.transformation as t
from pyrenew.deterministic import DeterministicPMF
from pyrenew.latent import (
Infections,
InfectionSeedingProcess,
SeedInfectionsZeroPad,
)
from pyrenew.metaclass import DistributionalRV
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.observation import PoissonObservation
from pyrenew.process import RtRandomWalkProcess


def create_test_model(): # numpydoc ignore=GL08
pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25])
gen_int = DeterministicPMF(pmf_array, name="gen_int")
I0 = InfectionSeedingProcess(
"I0_seeding",
DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"),
SeedInfectionsZeroPad(n_timepoints=gen_int.size()),
)
latent_infections = Infections()
observed_infections = PoissonObservation()
rt = RtRandomWalkProcess(
Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0),
Rt_transform=t.ExpTransform().inv,
Rt_rw_dist=dist.Normal(0, 0.025),
)
model = RtInfectionsRenewalModel(
I0_rv=I0,
gen_int_rv=gen_int,
latent_infections_rv=latent_infections,
infection_obs_process_rv=observed_infections,
Rt_process_rv=rt,
)
return model


def sample_test_model(
test_model, observed_infections, rng_key
): # numpydoc ignore=GL08
test_model.run(
num_warmup=50,
num_samples=50,
data_observed_infections=observed_infections,
rng_key=rng_key,
mcmc_args=dict(progress_bar=True),
)


def test_rng_keys_produce_correct_samples():
"""
Tests that the random keys specified for
MCMC sampling produce appropriate
output if left to None or specified directly.
"""

# set up singular epidemiological process

# set up base models for testing
model_01 = create_test_model()
model_02 = create_test_model()
model_03 = create_test_model()
model_04 = create_test_model()
model_05 = create_test_model()

# sample only a single model and use that model's samples
# as the observed_infections for the rest of the models
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
model_01_samp = model_01.sample(n_timepoints_to_simulate=30)

# run test models with the different keys
models = [model_01, model_02, model_03, model_04, model_05]
rng_keys = [jr.key(54), jr.key(54), None, None, jr.key(74)]
obs_infections = [model_01_samp.observed_infections] * len(models)
for elt in list(zip(models, obs_infections, rng_keys)):
sample_test_model(*elt)

# using same rng_key should get same run samples
assert np.array_equal(
model_01.mcmc.get_samples()["Rt"][0],
model_02.mcmc.get_samples()["Rt"][0],
)

# using None for rng_key should get different run samples
assert not np.array_equal(
model_03.mcmc.get_samples()["Rt"][0],
model_04.mcmc.get_samples()["Rt"][0],
)

# using None vs preselected rng_key should get different samples
assert not np.array_equal(
model_01.mcmc.get_samples()["Rt"][0],
model_03.mcmc.get_samples()["Rt"][0],
)

# using two different non-None rng keys should get different samples
assert not np.array_equal(
model_02.mcmc.get_samples()["Rt"][0],
model_05.mcmc.get_samples()["Rt"][0],
)

0 comments on commit 407bc4e

Please sign in to comment.