diff --git a/docs/pyproject.toml b/docs/pyproject.toml
index 1efb8f2e..dc67ded7 100644
--- a/docs/pyproject.toml
+++ b/docs/pyproject.toml
@@ -9,9 +9,9 @@ package-mode = false
 [tool.poetry.dependencies]
 python = "^3.12"
 sphinx = "^7.2.6"
-jax = "^0.4.25"
-jaxlib = "^0.4.25"
-numpyro = "^0.15.0"
+jax = ">=0.4.30"
+jaxlib = ">=0.4.30"
+numpyro = ">=0.15.1"
 sphinxcontrib-mermaid = "^0.9.2"
 polars = "^0.20.16"
 matplotlib = "^3.8.3"
diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd
index 3fc88342..e2433c99 100644
--- a/docs/source/tutorials/basic_renewal_model.qmd
+++ b/docs/source/tutorials/basic_renewal_model.qmd
@@ -214,7 +214,9 @@ model1.run(
     num_samples=1000,
     data_observed_infections=sim_data.observed_infections,
     rng_key=jax.random.PRNGKey(54),
-    mcmc_args=dict(progress_bar=False, num_chains=2, chain_method="sequential"),
+    mcmc_args=dict(
+        progress_bar=False, num_chains=2, chain_method="sequential"
+    ),
 )
 ```
 
@@ -335,7 +337,9 @@ az.plot_hdi(
 )
 
 # Add mean of the posterior to the figure
-mean_latent_infection = np.mean(idata.posterior["all_latent_infections"], axis=1)
+mean_latent_infection = np.mean(
+    idata.posterior["all_latent_infections"], axis=1
+)
 axes.plot(x_data, mean_latent_infection[0], color="C0", label="Mean")
 axes.legend()
 axes.set_title("Posterior Latent Infections", fontsize=10)
diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd
index f15e1c5b..f88a884e 100644
--- a/docs/source/tutorials/hospital_admissions_model.qmd
+++ b/docs/source/tutorials/hospital_admissions_model.qmd
@@ -81,11 +81,12 @@ Let's take a look at the daily prevalence of hospital admissions.
 # | fig-cap: Daily hospital admissions from the simulated data
 import matplotlib.pyplot as plt
 
+daily_hosp_admits = dat["daily_hosp_admits"].to_numpy()
 # Rotating the x-axis labels, and only showing ~10 labels
 ax = plt.gca()
 ax.xaxis.set_major_locator(plt.MaxNLocator(nbins=10))
 ax.xaxis.set_tick_params(rotation=45)
-plt.plot(dat["date"].to_numpy(), dat["daily_hosp_admits"].to_numpy())
+plt.plot(dat["date"].to_numpy(), daily_hosp_admits)
 plt.xlabel("Date")
 plt.ylabel("Admissions")
 plt.show()
@@ -147,7 +148,10 @@ The `inf_hosp_int` is a `DeterministicPMF` object that takes the infection to ho
 ```{python}
 # | label: initializing-rest-of-model
 from pyrenew import model, process, observation, metaclass, transformation
-from pyrenew.latent import InfectionSeedingProcess, SeedInfectionsExponentialGrowth
+from pyrenew.latent import (
+    InfectionSeedingProcess,
+    SeedInfectionsExponentialGrowth,
+)
 
 # Infection process
 latent_inf = latent.Infections()
@@ -237,9 +241,11 @@ npro.set_host_device_count(jax.local_device_count())
 hosp_model.run(
     num_samples=1000,
     num_warmup=1000,
-    data_observed_hosp_admissions=dat["daily_hosp_admits"].to_numpy(),
+    data_observed_hosp_admissions=daily_hosp_admits,
     rng_key=jax.random.PRNGKey(54),
-    mcmc_args=dict(progress_bar=False, num_chains=2, chain_method="sequential"),
+    mcmc_args=dict(
+        progress_bar=False, num_chains=2, chain_method="sequential"
+    ),
 )
 ```
 
@@ -254,7 +260,7 @@ out = hosp_model.plot_posterior(
     var="latent_hospital_admissions",
     ylab="Hospital Admissions",
     obs_signal=np.pad(
-        dat["daily_hosp_admits"].to_numpy().astype(float),
+        daily_hosp_admits.astype(float),
         (gen_int_array.size, 0),
         constant_values=np.nan,
     ),
@@ -270,10 +276,10 @@ import arviz as az
 idata = az.from_numpyro(
     hosp_model.mcmc,
     posterior_predictive=hosp_model.posterior_predictive(
-        n_timepoints_to_simulate=len(dat["daily_hosp_admits"])
+        n_timepoints_to_simulate=len(daily_hosp_admits)
     ),
     prior=hosp_model.prior_predictive(
-        n_timepoints_to_simulate=len(dat["daily_hosp_admits"]),
+        n_timepoints_to_simulate=len(daily_hosp_admits),
         numpyro_predictive_args={"num_samples": 1000},
     ),
 )
@@ -308,22 +314,15 @@ We can use the padding argument to solve the overestimation of hospital admissio
 ```{python}
 # | label: model-fit-padding
-days_to_impute = 21
-
-# Add 21 Nas to the beginning of dat_w_padding
-dat_w_padding = np.pad(
-    dat["daily_hosp_admits"].to_numpy().astype(float),
-    (days_to_impute, 0),
-    constant_values=np.nan,
-)
+pad_size = 21
 
 hosp_model.run(
     num_samples=1000,
     num_warmup=1000,
-    data_observed_hosp_admissions=dat_w_padding,
+    data_observed_hosp_admissions=daily_hosp_admits,
     rng_key=jax.random.PRNGKey(54),
     mcmc_args=dict(progress_bar=False, num_chains=2),
-    padding=days_to_impute,  # Padding the model
+    padding=pad_size,  # Padding the model
 )
 ```
 
@@ -336,7 +335,9 @@ out = hosp_model.plot_posterior(
     var="latent_hospital_admissions",
     ylab="Hospital Admissions",
     obs_signal=np.pad(
-        dat_w_padding, (gen_int_array.size, 0), constant_values=np.nan
+        daily_hosp_admits.astype(float),
+        (gen_int_array.size + pad_size, 0),
+        constant_values=np.nan,
     ),
 )
 ```
@@ -407,7 +408,9 @@ We can look at individual draws from the posterior distribution of latent infect
 ```{python}
 # | label: fig-output-infections-with-padding
 # | fig-cap: Latent infections
-out2 = hosp_model.plot_posterior(var="all_latent_infections", ylab="Latent Infections")
+out2 = hosp_model.plot_posterior(
+    var="all_latent_infections", ylab="Latent Infections"
+)
 ```
 
 We can also look at credible intervals for the posterior distribution of latent infections:
@@ -440,7 +443,9 @@ az.plot_hdi(
 )
 
 # Add mean of the posterior to the figure
-mean_latent_infection = np.mean(idata.posterior["all_latent_infections"], axis=1)
+mean_latent_infection = np.mean(
+    idata.posterior["all_latent_infections"], axis=1
+)
 axes.plot(x_data, mean_latent_infection[0], color="C0", label="Mean")
 axes.legend()
 axes.set_title("Posterior Latent Infections", fontsize=10)
@@ -520,10 +525,10 @@ Running the model (with the same padding as before):
 hosp_model_weekday.run(
     num_samples=2000,
     num_warmup=2000,
-    data_observed_hosp_admissions=dat_w_padding,
+    data_observed_hosp_admissions=daily_hosp_admits,
     rng_key=jax.random.PRNGKey(54),
     mcmc_args=dict(progress_bar=False),
-    padding=days_to_impute,
+    padding=pad_size,
 )
 ```
 
@@ -535,21 +540,29 @@ And plotting the results:
 out = hosp_model_weekday.plot_posterior(
     var="latent_hospital_admissions",
     ylab="Hospital Admissions",
-    obs_signal=np.pad(dat_w_padding, (gen_int_array.size, 0), constant_values=np.nan),
+    obs_signal=np.pad(
+        daily_hosp_admits.astype(float),
+        (gen_int_array.size + pad_size, 0),
+        constant_values=np.nan,
+    ),
 )
 ```
 
-We will use ArviZ to visualize the posterior/prior predictive distributions.
+We will use ArviZ to visualize the posterior and prior predictive distributions.
+By increasing `n_timepoints_to_simulate`, we can perform forecasting using the posterior predictive distribution.
 
 ```{python}
 # | label: posterior-predictive-distribution
+n_forecast_points = 28
 idata_weekday = az.from_numpyro(
     hosp_model_weekday.mcmc,
     posterior_predictive=hosp_model_weekday.posterior_predictive(
-        n_timepoints_to_simulate=len(dat_w_padding)
+        n_timepoints_to_simulate=len(daily_hosp_admits) + n_forecast_points,
+        padding=pad_size,
     ),
     prior=hosp_model_weekday.prior_predictive(
-        n_timepoints_to_simulate=len(dat_w_padding),
+        n_timepoints_to_simulate=len(daily_hosp_admits),
+        padding=pad_size,
         numpyro_predictive_args={"num_samples": 1000},
     ),
 )
@@ -569,7 +582,9 @@ def compute_eti(dataset, eti_prob):
 
 fig, axes = plt.subplots(figsize=(6, 5))
 az.plot_hdi(
-    idata_weekday.prior_predictive["negbinom_rv_dim_0"],
+    idata_weekday.prior_predictive["negbinom_rv_dim_0"]
+    + pad_size
+    + gen_int.size(),
     hdi_data=compute_eti(idata_weekday.prior_predictive["negbinom_rv"], 0.9),
     color="C0",
     smooth=False,
@@ -578,7 +593,9 @@ az.plot_hdi(
 )
 
 az.plot_hdi(
-    idata_weekday.prior_predictive["negbinom_rv_dim_0"],
+    idata_weekday.prior_predictive["negbinom_rv_dim_0"]
+    + pad_size
+    + gen_int.size(),
     hdi_data=compute_eti(idata_weekday.prior_predictive["negbinom_rv"], 0.5),
     color="C0",
     smooth=False,
@@ -587,7 +604,9 @@ az.plot_hdi(
 )
 
 plt.scatter(
-    idata_weekday.observed_data["negbinom_rv_dim_0"] + days_to_impute,
+    idata_weekday.observed_data["negbinom_rv_dim_0"]
+    + pad_size
+    + gen_int.size(),
     idata_weekday.observed_data["negbinom_rv"],
     color="black",
 )
@@ -599,14 +618,18 @@ plt.yscale("log")
 plt.show()
 ```
 
-And now we plot the posterior predictive distributions:
+And now we plot the posterior predictive distributions with a `{python} n_forecast_points`-day-ahead forecast:
 
 ```{python}
 # | label: fig-output-posterior-predictive
 # | fig-cap: Posterior Predictive Infections
 fig, axes = plt.subplots(figsize=(6, 5))
 az.plot_hdi(
-    idata_weekday.posterior_predictive["negbinom_rv_dim_0"],
-    hdi_data=compute_eti(idata_weekday.posterior_predictive["negbinom_rv"], 0.9),
+    idata_weekday.posterior_predictive["negbinom_rv_dim_0"]
+    + pad_size
+    + gen_int.size(),
+    hdi_data=compute_eti(
+        idata_weekday.posterior_predictive["negbinom_rv"], 0.9
+    ),
     color="C0",
     smooth=False,
     fill_kwargs={"alpha": 0.3},
@@ -614,8 +637,12 @@ az.plot_hdi(
 )
 
 az.plot_hdi(
-    idata_weekday.posterior_predictive["negbinom_rv_dim_0"],
-    hdi_data=compute_eti(idata_weekday.posterior_predictive["negbinom_rv"], 0.5),
+    idata_weekday.posterior_predictive["negbinom_rv_dim_0"]
+    + pad_size
+    + gen_int.size(),
+    hdi_data=compute_eti(
+        idata_weekday.posterior_predictive["negbinom_rv"], 0.5
+    ),
     color="C0",
     smooth=False,
     fill_kwargs={"alpha": 0.6},
@@ -628,13 +655,17 @@ mean_latent_infection = np.mean(
 )
 
 plt.plot(
-    idata_weekday.posterior_predictive["negbinom_rv_dim_0"],
+    idata_weekday.posterior_predictive["negbinom_rv_dim_0"]
+    + pad_size
+    + gen_int.size(),
     mean_latent_infection[0],
     color="C0",
     label="Mean",
 )
 plt.scatter(
-    idata_weekday.observed_data["negbinom_rv_dim_0"] + days_to_impute,
+    idata_weekday.observed_data["negbinom_rv_dim_0"]
+    + pad_size
+    + gen_int.size(),
     idata_weekday.observed_data["negbinom_rv"],
     color="black",
 )
diff --git a/model/pyproject.toml b/model/pyproject.toml
index 63917f66..fb164c2e 100755
--- a/model/pyproject.toml
+++ b/model/pyproject.toml
@@ -11,8 +11,9 @@ exclude = [{path = "datasets/*.rds"}]
 
 [tool.poetry.dependencies]
 python = "^3.12"
-numpyro = "^0.15.0"
-jax = "^0.4.25"
+numpyro = ">=0.15.1"
+jax = ">=0.4.30"
+jaxlib = ">=0.4.30"
 numpy = "^1.26.4"
 polars = "^0.20.16"
 pillow = "^10.3.0" # See #56 on CDCgov/multisignal-epi-inference
diff --git a/model/src/pyrenew/model/admissionsmodel.py b/model/src/pyrenew/model/admissionsmodel.py
index a190bcfc..caf5bbdb 100644
--- a/model/src/pyrenew/model/admissionsmodel.py
+++ b/model/src/pyrenew/model/admissionsmodel.py
@@ -5,9 +5,8 @@
 from typing import NamedTuple
 
-import jax.numpy as jnp
-import pyrenew.arrayutils as au
 from jax.typing import ArrayLike
+from pyrenew.deterministic import NullObservation
 from pyrenew.metaclass import Model, RandomVariable, _assert_sample_and_rtype
 from pyrenew.model.rtinfectionsrenewalmodel import RtInfectionsRenewalModel
 
@@ -99,6 +98,9 @@ def __init__(
         )
         self.latent_hosp_admissions_rv = latent_hosp_admissions_rv
+        if hosp_admission_obs_process_rv is None:
+            hosp_admission_obs_process_rv = NullObservation()
+
         self.hosp_admission_obs_process_rv = hosp_admission_obs_process_rv
 
     @staticmethod
@@ -178,13 +180,13 @@ def sample(
                 "Cannot pass both n_timepoints_to_simulate and data_observed_hosp_admissions."
             )
         elif n_timepoints_to_simulate is None:
-            n_timepoints = len(data_observed_hosp_admissions)
+            n_datapoints = len(data_observed_hosp_admissions)
         else:
-            n_timepoints = n_timepoints_to_simulate
+            n_datapoints = n_timepoints_to_simulate
 
         # Getting the initial quantities from the basic model
         basic_model = self.basic_renewal.sample(
-            n_timepoints_to_simulate=n_timepoints,
+            n_timepoints_to_simulate=n_datapoints,
             data_observed_infections=None,
             padding=padding,
             **kwargs,
@@ -199,35 +201,15 @@ def sample(
             latent_infections=basic_model.latent_infections,
             **kwargs,
         )
-        i0_size = len(latent_hosp_admissions) - n_timepoints
-        if self.hosp_admission_obs_process_rv is None:
-            observed_hosp_admissions = None
-        else:
-            if data_observed_hosp_admissions is None:
-                (
-                    observed_hosp_admissions,
-                    *_,
-                ) = self.hosp_admission_obs_process_rv.sample(
-                    mu=latent_hosp_admissions[i0_size + padding :],
-                    obs=data_observed_hosp_admissions,
-                    **kwargs,
-                )
-            else:
-                data_observed_hosp_admissions = au.pad_x_to_match_y(
-                    data_observed_hosp_admissions,
-                    latent_hosp_admissions,
-                    jnp.nan,
-                    pad_direction="start",
-                )
-
-                (
-                    observed_hosp_admissions,
-                    *_,
-                ) = self.hosp_admission_obs_process_rv.sample(
-                    mu=latent_hosp_admissions[i0_size + padding :],
-                    obs=data_observed_hosp_admissions[i0_size + padding :],
-                    **kwargs,
-                )
+
+        (
+            observed_hosp_admissions,
+            *_,
+        ) = self.hosp_admission_obs_process_rv.sample(
+            mu=latent_hosp_admissions[-n_datapoints:],
+            obs=data_observed_hosp_admissions,
+            **kwargs,
+        )
 
         return HospModelSample(
             Rt=basic_model.Rt,
diff --git a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py
index 2b5498f8..afbdb8d0 100644
--- a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py
+++ b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py
@@ -187,9 +187,9 @@ def sample(
                 "Cannot pass both n_timepoints_to_simulate and data_observed_infections."
             )
         elif n_timepoints_to_simulate is None:
-            n_timepoints = len(data_observed_infections)
+            n_timepoints = len(data_observed_infections) + padding
         else:
-            n_timepoints = n_timepoints_to_simulate
+            n_timepoints = n_timepoints_to_simulate + padding
 
         # Sampling from Rt (possibly with a given Rt, depending on
         # the Rt_process (RandomVariable) object.)
         Rt, *_ = self.Rt_process_rv.sample(
@@ -210,9 +210,6 @@ def sample(
             **kwargs,
         )
 
-        if data_observed_infections is not None:
-            data_observed_infections = data_observed_infections[padding:]
-
         observed_infections, *_ = self.infection_obs_process_rv.sample(
             mu=post_seed_latent_infections[padding:],
             obs=data_observed_infections,
diff --git a/model/src/pyrenew/process/simplerandomwalk.py b/model/src/pyrenew/process/simplerandomwalk.py
index d2c233b3..c18bec67 100644
--- a/model/src/pyrenew/process/simplerandomwalk.py
+++ b/model/src/pyrenew/process/simplerandomwalk.py
@@ -4,6 +4,7 @@
 import jax.numpy as jnp
 import numpyro as npro
 import numpyro.distributions as dist
+from numpyro.contrib.control_flow import scan
 from pyrenew.metaclass import RandomVariable
 
@@ -62,12 +63,20 @@ def sample(
         if init is None:
             init = npro.sample(name + "_init", self.error_distribution)
-        diffs = npro.sample(
-            name + "_diffs",
-            self.error_distribution.expand((n_timepoints - 1,)),
+
+        def transition(x_prev, _):
+            # numpydoc ignore=GL08
+            diff = npro.sample(name + "_diffs", self.error_distribution)
+            x_curr = x_prev + diff
+            return x_curr, x_curr
+
+        _, x = scan(
+            transition,
+            init=init,
+            xs=jnp.arange(n_timepoints - 1),
         )
-        return (init + jnp.cumsum(jnp.pad(diffs, [1, 0], constant_values=0)),)
+        return (jnp.hstack([init, x]),)
 
     @staticmethod
     def validate():
diff --git a/model/src/test/test_forecast.py b/model/src/test/test_forecast.py
new file mode 100644
index 00000000..677755c5
--- /dev/null
+++ b/model/src/test/test_forecast.py
@@ -0,0 +1,78 @@
+# numpydoc ignore=GL08
+
+import jax.numpy as jnp
+import jax.random as jr
+import numpy as np
+import numpyro as npro
+import numpyro.distributions as dist
+import pyrenew.transformation as t
+from numpy.testing import assert_array_equal
+from pyrenew.deterministic import DeterministicPMF
+from pyrenew.latent import (
+    Infections,
+    InfectionSeedingProcess,
+    SeedInfectionsZeroPad,
+)
+from pyrenew.metaclass import DistributionalRV
+from pyrenew.model import RtInfectionsRenewalModel
+from pyrenew.observation import PoissonObservation
+from pyrenew.process import RtRandomWalkProcess
+
+
+def test_forecast():
+    """Check that forecasts are the right length and match the posterior up until forecast begins."""
+    pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25])
+    gen_int = DeterministicPMF(pmf_array, name="gen_int")
+    I0 = InfectionSeedingProcess(
+        "I0_seeding",
+        DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"),
+        SeedInfectionsZeroPad(n_timepoints=gen_int.size()),
+        t_unit=1,
+    )
+    latent_infections = Infections()
+    observed_infections = PoissonObservation()
+    rt = RtRandomWalkProcess(
+        Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0),
+        Rt_transform=t.ExpTransform().inv,
+        Rt_rw_dist=dist.Normal(0, 0.025),
+    )
+    model = RtInfectionsRenewalModel(
+        I0_rv=I0,
+        gen_int_rv=gen_int,
+        latent_infections_rv=latent_infections,
+        infection_obs_process_rv=observed_infections,
+        Rt_process_rv=rt,
+    )
+
+    n_timepoints_to_simulate = 30
+    n_forecast_points = 10
+    with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
+        model_sample = model.sample(
+            n_timepoints_to_simulate=n_timepoints_to_simulate
+        )
+
+    model.run(
+        num_warmup=5,
+        num_samples=5,
+        data_observed_infections=model_sample.observed_infections,
+        rng_key=jr.key(54),
+    )
+
+    posterior_predictive_samples = model.posterior_predictive(
+        n_timepoints_to_simulate=n_timepoints_to_simulate + n_forecast_points,
+    )
+
+    # Check the length of the predictive distribution
+    assert (
+        len(posterior_predictive_samples["poisson_rv"][0])
+        == n_timepoints_to_simulate + n_forecast_points
+    )
+
+    # Check the first elements of the posterior predictive Rt are the same as the
+    # posterior Rt
+    assert_array_equal(
+        model.mcmc.get_samples()["Rt"][0],
+        posterior_predictive_samples["Rt"][0][
+            : len(model.mcmc.get_samples()["Rt"][0])
+        ],
+    )
diff --git a/model/src/test/test_model_basic_renewal.py b/model/src/test/test_model_basic_renewal.py
index c8c53ec3..f79a7e4c 100644
--- a/model/src/test/test_model_basic_renewal.py
+++ b/model/src/test/test_model_basic_renewal.py
@@ -267,18 +267,17 @@ def test_model_basicrenewal_padding() -> None:  # numpydoc ignore=GL08
 
     # Sampling and fitting model 1 (with obs infections)
     np.random.seed(2203)
+    pad_size = 5
     with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
-        model1_samp = model1.sample(n_timepoints_to_simulate=30)
-
-    new_obs = jnp.hstack(
-        [jnp.repeat(jnp.nan, 5), model1_samp.observed_infections[5:]],
-    )
+        model1_samp = model1.sample(
+            n_timepoints_to_simulate=30, padding=pad_size
+        )
 
     model1.run(
         num_warmup=500,
         num_samples=500,
         rng_key=jr.key(22),
-        data_observed_infections=new_obs,
+        data_observed_infections=model1_samp.observed_infections,
         padding=5,
     )
 
diff --git a/model/src/test/test_model_hospitalizations.py b/model/src/test/test_model_hospitalizations.py
index d6d5c023..dc2b090a 100644
--- a/model/src/test/test_model_hospitalizations.py
+++ b/model/src/test/test_model_hospitalizations.py
@@ -253,13 +253,13 @@ def test_model_hosp_no_obs_model():
     with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
         model0_samp = model0.sample(n_timepoints_to_simulate=30)
 
-    model0.observation_process = NullObservation()
+    model0.hosp_admission_obs_process_rv = NullObservation()
 
     np.random.seed(223)
     with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
         model1_samp = model0.sample(n_timepoints_to_simulate=30)
 
-    np.testing.assert_array_equal(model0_samp.Rt, model1_samp.Rt)
+    np.testing.assert_array_almost_equal(model0_samp.Rt, model1_samp.Rt)
     np.testing.assert_array_equal(
         model0_samp.latent_infections, model1_samp.latent_infections
     )
@@ -493,6 +493,7 @@ def test_model_hosp_with_obs_model_weekday_phosp():
         jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int"
     )
     n_obs_to_generate = 30
+    pad_size = 5
 
     I0 = InfectionSeedingProcess(
@@ -536,17 +537,17 @@ def test_model_hosp_with_obs_model_weekday_phosp():
     )
 
     # Other random components
+    total_length = n_obs_to_generate + pad_size + gen_int.size()
     weekday = jnp.array([1, 1, 1, 1, 2, 2])
     weekday = weekday / weekday.sum()
     weekday = jnp.tile(weekday, 10)
-    # weekday = weekday[:n_obs_to_generate]
-    weekday = weekday[:34]
+    weekday = weekday[:total_length]
     weekday = DeterministicVariable(weekday, name="weekday")
 
     hosp_report_prob_dist = jnp.array([0.9, 0.8, 0.7, 0.7, 0.6, 0.4])
     hosp_report_prob_dist = jnp.tile(hosp_report_prob_dist, 10)
-    hosp_report_prob_dist = hosp_report_prob_dist[:34]
+    hosp_report_prob_dist = hosp_report_prob_dist[:total_length]
     hosp_report_prob_dist = hosp_report_prob_dist / hosp_report_prob_dist.sum()
 
     hosp_report_prob_dist = DeterministicVariable(
@@ -572,23 +573,19 @@ def test_model_hosp_with_obs_model_weekday_phosp():
     )
 
     # Sampling and fitting model 0 (with no obs for infections)
+    np.random.seed(223)
     with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
-        model1_samp = model1.sample(n_timepoints_to_simulate=n_obs_to_generate)
+        model1_samp = model1.sample(
+            n_timepoints_to_simulate=n_obs_to_generate, padding=pad_size
+        )
 
-    pad_size = 5
-    obs = jnp.hstack(
-        [
-            jnp.repeat(jnp.nan, pad_size),
-            model1_samp.observed_hosp_admissions[pad_size:],
-        ]
-    )
     # Running with padding
     model1.run(
         num_warmup=500,
         num_samples=500,
         rng_key=jr.key(272),
-        data_observed_hosp_admissions=obs,
+        data_observed_hosp_admissions=model1_samp.observed_hosp_admissions,
         padding=pad_size,
     )
 
diff --git a/model/src/test/test_random_key.py b/model/src/test/test_random_key.py
index f173012c..a2342552 100644
--- a/model/src/test/test_random_key.py
+++ b/model/src/test/test_random_key.py
@@ -94,6 +94,9 @@ def test_rng_keys_produce_correct_samples():
     # set up base models for testing
     models = [create_test_model() for _ in range(5)]
     n_timepoints_to_simulate = [30] * len(models)
+    n_timepoints_posterior_predictive = [
+        x + models[0].gen_int_rv.size() for x in n_timepoints_to_simulate
+    ]
     # sample only a single model and use that model's samples
     # as the observed_infections for the rest of the models
     with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
@@ -115,7 +118,9 @@ def test_rng_keys_produce_correct_samples():
 
     posterior_predictive_list = [
         posterior_predictive_test_model(*elt)
-        for elt in list(zip(models, n_timepoints_to_simulate, rng_keys))
+        for elt in list(
+            zip(models, n_timepoints_posterior_predictive, rng_keys)
+        )
     ]
     # using same rng_key should get same run samples
     assert_array_equal(
diff --git a/model/src/test/test_random_walk.py b/model/src/test/test_random_walk.py
index 9f1335e1..bd56d910 100755
--- a/model/src/test/test_random_walk.py
+++ b/model/src/test/test_random_walk.py
@@ -16,7 +16,7 @@ def test_rw_can_be_sampled():
 
     with numpyro.handlers.seed(rng_seed=62):
         # can sample with and without inits
-        ans0 = rw_normal.sample(3532, init=jnp.array([50.0]))
+        ans0 = rw_normal.sample(3532, init=50.0)
         ans1 = rw_normal.sample(5023)
 
     # check that the samples are of the right shape
@@ -35,9 +35,9 @@ def test_rw_samples_correctly_distributed():
         [0, 2.253, -3.2521, 1052, 1e-6], [1, 0.025, 3, 1, 0.02]
     ):
         rw_normal = SimpleRandomWalkProcess(dist.Normal(step_mean, step_sd))
-        init_arr = jnp.array([532.0])
+        rw_init = 532.0
         with numpyro.handlers.seed(rng_seed=62):
-            samples, *_ = rw_normal.sample(n_samples, init=init_arr)
+            samples, *_ = rw_normal.sample(n_samples, init=rw_init)
 
             # Checking the shape
             assert samples.shape == (n_samples,)
@@ -60,4 +60,4 @@ def test_rw_samples_correctly_distributed():
             assert jnp.abs(jnp.log(jnp.std(diffs) / step_sd)) < jnp.log(1.1)
 
             # first value should be the init value
-            assert_almost_equal(samples[0], init_arr)
+            assert_almost_equal(samples[0], rw_init)
diff --git a/pyproject.toml b/pyproject.toml
index 27ab0504..8227e76f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,6 +9,7 @@ packages = [{include = "multisignal_epi_inference"}]
 
 [tool.poetry.dependencies]
 python = "^3.12"
+numpyro = ">=0.15.1"
 
 [tool.poetry.group.dev]
 optional = true
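
A standalone sketch of the scan-based random walk introduced in `simplerandomwalk.py` above. This is not part of the patch: the helper name `random_walk_scan` and the use of `numpyro.handlers.seed` to run it outside a model are illustrative assumptions. Each `scan` step draws one increment from the error distribution and carries the running value forward, reproducing the behavior of the previous cumulative-sum draw.

```python
# Illustrative sketch only; random_walk_scan is a hypothetical helper,
# not part of the pyrenew API.
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan


def random_walk_scan(name, error_distribution, n_timepoints, init):
    # Each scan step samples one increment and carries the running value
    # forward, mirroring transition() in SimpleRandomWalkProcess.sample().
    def transition(x_prev, _):
        diff = numpyro.sample(name + "_diffs", error_distribution)
        x_curr = x_prev + diff
        return x_curr, x_curr

    _, x = scan(transition, init=init, xs=jnp.arange(n_timepoints - 1))
    return jnp.hstack([init, x])


with numpyro.handlers.seed(rng_seed=0):
    walk = random_walk_scan("rw", dist.Normal(0.0, 1.0), 10, jnp.array(0.0))

print(walk.shape)  # (10,): the init value followed by 9 scanned steps
```

Structuring the walk as a `scan` over timesteps, rather than a single fixed-length `_diffs` draw, is what the new `test_forecast.py` exercises when it requests a posterior predictive longer than the fitted series.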