Update ARProcess and InfectionInitializationProcess to handle batched input #423
Comments
Minimal reproducible example? Which inputs are causing …? Also, what happens if you instantiate the process RV outside the plate but call its `sample()` method within the plate?
This reproduces the issue with using …

One-line error summary: …
This is blocking CDCgov/pyrenew-hew#7.

Suggest splitting this into sub-issues, one for each of the two RV classes.
Rewriting the initial exponential growth in terms of matrix multiplication seems to do what we want. However, it still feels a bit automatic. For instance, how does it know how to batch …? I get

```
UserWarning: Missing a plate statement for batch dimension -2 at site 'obs'. You can use `numpyro.util.format_shapes` utility to check shapes at all sites of your model.
```

at `mcmc.run(rng_key, y=y_data)`, but I'm not sure where to put this.

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS

n_subpops = 3
rate = 1 + jnp.pow(10.0, -(jnp.arange(n_subpops) + 1))
n_timepoints = 10
i0 = jnp.arange(n_subpops) + 1
# (n_timepoints, 1) broadcast against (n_subpops,) -> (n_timepoints, n_subpops)
y_data = i0 * jnp.exp(rate * jnp.expand_dims(jnp.arange(n_timepoints), 1))


def my_model(y):
    with numpyro.plate("subpop", n_subpops):
        rate = numpyro.sample("rate", dist.HalfNormal())
        i0 = numpyro.sample("i0", dist.HalfNormal())
        mean_infec = i0 * jnp.exp(
            rate * jnp.expand_dims(jnp.arange(n_timepoints), 1)
        )
        # The plate covers dim -1 (subpops) only; dim -2 (time) has no plate,
        # which is what the warning reports.
        numpyro.sample("obs", dist.Poisson(mean_infec), obs=y)


# Posterior sampling
nuts_kernel = NUTS(my_model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, y=y_data)

# Check results
mcmc.print_summary()
```
Maybe a second …

For the warning, without …:

```python
mean_infec = i0 * jnp.exp(
    rate * jnp.expand_dims(jnp.arange(n_timepoints), 1)
)
numpyro.sample("obs", dist.Poisson(mean_infec), obs=y)
```

Output: …

and with the …
Modified, for no warning:

```python
# Uses the imports and data from the earlier example.
def my_model(y):
    with numpyro.plate("subpop", n_subpops):  # claims batch dim -1
        rate = numpyro.sample("rate", dist.HalfNormal())
        i0 = numpyro.sample("i0", dist.HalfNormal())
        with numpyro.plate("timepoint", n_timepoints):  # claims batch dim -2
            mean_infec = i0 * jnp.exp(
                rate * jnp.expand_dims(jnp.arange(n_timepoints), 1)
            )
            numpyro.sample("obs", dist.Poisson(mean_infec), obs=y)
```
Thanks @AFg6K7h4fhy2. I'm glad we have something working without a warning, but this all feels a bit funny to me. In particular, it seems like we should be able to come up with a solution that doesn't require the time series to be of equal length.

Agree. I'm not going to investigate the …
This works with arbitrary time series lengths. It doesn't feel very modular, though.

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS

n_subpops = 4
rates = 1 + jnp.pow(10.0, -(jnp.arange(n_subpops) + 1))
n_timepoints = jnp.arange(n_subpops) + 10  # a different series length per subpop
i0s = jnp.arange(n_subpops) + 1

# Concatenate all series into one long vector...
y_data = jnp.concatenate(
    [
        i0 * jnp.exp(rate * jnp.arange(n_timepoint))
        for rate, i0, n_timepoint in zip(rates, i0s, n_timepoints)
    ]
)
# ...and record, for each entry, its subpop and its time index
y_ind = jnp.repeat(jnp.arange(n_subpops), n_timepoints)
y_time = jnp.concatenate(
    [jnp.arange(n_timepoint) for n_timepoint in n_timepoints]
)


def my_model(y_data, y_ind, y_time):
    with numpyro.plate("subpop", n_subpops):
        rate = numpyro.sample("rate", dist.HalfNormal())
        i0 = numpyro.sample("i0", dist.HalfNormal())
    # Index the per-subpop parameters to line them up with the flattened data
    mean_infec = i0[y_ind] * jnp.exp(rate[y_ind] * y_time)
    numpyro.sample("obs", dist.Poisson(mean_infec), obs=y_data)


# Posterior sampling
nuts_kernel = NUTS(my_model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, y_data=y_data, y_ind=y_ind, y_time=y_time)

# Check results
mcmc.print_summary()
```
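The index vectors above are built with `jnp.repeat` plus a Python list comprehension. A sketch of deriving both `y_ind` and `y_time` from the series lengths with array operations only, using an exclusive cumulative sum to find where each group's block starts. `long_format_indices` is a hypothetical helper name, not part of PyRenew or NumPyro.

```python
import jax.numpy as jnp


def long_format_indices(lengths):
    """Hypothetical helper: group index and within-group time index for
    ragged series laid end to end. `lengths[g]` is the length of series g."""
    lengths = jnp.asarray(lengths)
    n_groups = lengths.shape[0]
    total = int(lengths.sum())
    # Group label for every flattened entry: [0]*len0 + [1]*len1 + ...
    y_ind = jnp.repeat(jnp.arange(n_groups), lengths, total_repeat_length=total)
    # Flat position where each group's block starts (exclusive cumulative sum).
    starts = jnp.cumsum(lengths) - lengths
    # Within-group time = flat position minus the start of that entry's block.
    y_time = jnp.arange(total) - starts[y_ind]
    return y_ind, y_time


y_ind, y_time = long_format_indices([2, 3])
print(y_ind)   # [0 0 1 1 1]
print(y_time)  # [0 1 0 1 2]
```

For the lengths `[10, 11, 12, 13]` used above, this should reproduce the same `y_ind` and `y_time` as the comprehension-based version.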
After discussing with @dylanhmorris, we have agreed on …

I am working on a revised version of the above model to demonstrate these recommendations.
Here is an updated example based on the above recommendations. It is a bit unclear what this implies for the rest of PyRenew. I think we may need to stick in some …

```python
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS
import polars as pl
import string

n_groups = 4
rates = 1 + jnp.pow(10.0, -(jnp.arange(n_groups) + 1))
n_timepoints = jnp.arange(n_groups) + 10
i0s = jnp.arange(n_groups) + 1

# Long-format data; dropping group "a" at time 4 creates some implicitly missing data
input_data = pl.DataFrame(
    {
        "group": pl.Series(
            np.array(list(string.ascii_lowercase))[
                np.repeat(np.arange(n_groups), n_timepoints)
            ],
            dtype=pl.Categorical,
        ),
        "time": np.concatenate(
            [np.arange(n_timepoint) for n_timepoint in n_timepoints]
        ),
        "obs": np.concatenate(
            [
                i0 * np.exp(rate * np.arange(n_timepoint))
                for rate, i0, n_timepoint in zip(rates, i0s, n_timepoints)
            ]
        ),
    }
).filter(~((pl.col("group") == "a") & (pl.col("time") == 4)))

y_group = input_data["group"].to_numpy()
y_time = input_data["time"].to_numpy()
y_obs = input_data["obs"].to_numpy()

# This would be done at model instantiation:
y_group_ind = input_data["group"].to_physical().to_numpy()
y_time_max = input_data["time"].max()


def my_model(y_group, y_time, y_obs):
    with numpyro.plate("group", n_groups):
        rate = numpyro.sample("rate", dist.HalfNormal())
        i0 = numpyro.sample("i0", dist.HalfNormal())
    # Dense (time, group) grid of expected infections
    mean_infec = i0 * jnp.exp(rate * jnp.arange(y_time_max + 1)[:, jnp.newaxis])
    # Only the observed (time, group) cells enter the likelihood
    numpyro.sample("obs", dist.Poisson(mean_infec[y_time, y_group]), obs=y_obs)


# Posterior sampling
nuts_kernel = NUTS(my_model, find_heuristic_step_size=True)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, y_group=y_group_ind, y_time=y_time, y_obs=y_obs)

# Check results
mcmc.print_summary()
```
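The key step in this version is `mean_infec[y_time, y_group]`: expected values are computed on a dense `(time, group)` grid, and paired integer-array indexing then picks one cell per observed row, so a missing `(time, group)` combination simply never appears. A toy sketch of just that step, with made-up parameter values and one cell deliberately absent:

```python
import jax.numpy as jnp

# Made-up per-group parameters (illustrative values only).
rate = jnp.array([0.1, 0.2])
i0 = jnp.array([1.0, 2.0])
t_max = 3

# Dense grid: row t, column g holds i0[g] * exp(rate[g] * t); shape (4, 2).
mean_grid = i0 * jnp.exp(rate * jnp.arange(t_max + 1)[:, jnp.newaxis])

# Long-format coordinates of the observed rows; (time=1, group=0) is missing.
obs_time = jnp.array([0, 2, 3, 0, 1, 2, 3])
obs_group = jnp.array([0, 0, 0, 1, 1, 1, 1])

# Paired indexing: element k is mean_grid[obs_time[k], obs_group[k]].
mean_at_obs = mean_grid[obs_time, obs_group]
print(mean_grid.shape, mean_at_obs.shape)  # (4, 2) (7,)
```

The grid costs `(t_max + 1) * n_groups` evaluations even when many cells are unobserved, which is the trade-off for keeping the model vectorized.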
Was the infection initialization process part of this issue handled by #432, @damonbayer? Or do other initialization schemes, such as …

It's a non-trivial fix. I would like to first consider removing the …

I think that's reasonable.

I think the …

Partially closed by #432.

Fully closed by #439.
Currently, `ARProcess` and `InfectionInitializationProcess` (and possibly others) require scalar inputs, which prevents them from being used with `numpyro.plate`. This emerged while trying to use `numpyro.plate` to calculate site-level dynamics for each site in CDCgov/pyrenew-hew#7.

I am also open to other ways of handling this, and it may be worth having a discussion about the best approach. For reference, I tried using a `for` loop, but that required modifying the `name` arguments of the `RandomVariable`s. @dylanhmorris mentioned potentially using `jax.lax.scan`.
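Since `jax.lax.scan` came up: a minimal sketch, assuming a plain AR(1) recursion rather than the actual `ARProcess` signature, showing that a scan whose carry is a vector handles all groups at once. The per-step arithmetic is elementwise, so no per-group loop (and no renamed sample sites) is needed. `ar1_scan` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def ar1_scan(init, coef, noise):
    """Hypothetical helper (not PyRenew's ARProcess API).

    Computes x[t] = coef * x[t - 1] + noise[t] for t = 0..T-1, starting from
    x[-1] = init, for every group simultaneously.

    init:  (n_groups,) starting values
    coef:  (n_groups,) per-group AR(1) coefficients
    noise: (n_timepoints, n_groups) innovations
    returns: (n_timepoints, n_groups)
    """

    def step(prev, eps):
        # `prev` and `eps` both have shape (n_groups,); all ops are elementwise.
        new = coef * prev + eps
        return new, new  # (next carry, value stacked into the output)

    # scan iterates over the leading (time) axis of `noise`.
    _, xs = jax.lax.scan(step, init, noise)
    return xs


init = jnp.array([1.0, 2.0, 3.0])
coef = jnp.array([0.5, 0.9, 0.0])
noise = jnp.zeros((4, 3))
xs = ar1_scan(init, coef, noise)
print(xs.shape)  # (4, 3)
print(xs[:, 0])  # halves each step: 0.5, 0.25, 0.125, 0.0625
```

In a model, `noise` could be sampled inside nested plates with shape `(n_timepoints, n_groups)`, in the same layout as the nested-plate example earlier in the thread.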