Add scan primitive #595
@neerajprad @eb8680 I have made a

def target(T=10, q=1, r=1, phi=0., beta=0.):
    def transition(state, i):
        x0, mu0 = state
        x1 = numpyro.sample('x', dist.Normal(phi * x0, q))
        mu1 = beta * mu0 + x1
        y1 = numpyro.sample('y', dist.Normal(mu1, r))
        return (x1, mu1), (x1, y1)

    mu0 = x0 = numpyro.sample('x_0', dist.Normal(0, q))
    y0 = numpyro.sample('y_0', dist.Normal(mu0, r))
    _, xy = scan("scan", transition, (x0, mu0), np.arange(1, T))
    x, y = xy
    return np.append(x0, x), np.append(y0, y)

returns
as expected. (I marked this as "blocked" because I want to address
# XXX: this won't work with MultivariateNormal, which has different interpretations
# depending on which one specified among `scale_tril`, `covariance_matrix`, `precision_matrix`
# TODO: arrange keys in arg_constraints to be the same order as constructor
return tuple(getattr(self, param) for param in self.arg_constraints.keys()), None
Would this work with all distribution classes, such as `MaskedDistribution` and `ExpandedDistribution`?
Good point! I think it should work, but we need to override the methods in such subclasses. Auxiliary information such as `batch_shape` and `reinterpreted_batch_ndims` can be put into the second output (instead of `None` here), and it will be recovered correctly through the `aux_data` argument of `tree_unflatten`. `mask` can be scanned like the other `arg_constraints`. TransformedDistribution is trickier; I think we need to register pytree for them too. We can add a try/except here (e.g. to check whether `self.arg_constraints` is non-empty) and otherwise raise an error asking users to submit a feature request. :)
It might be better to provide `Distribution.flatten` or `Distribution.unflatten` methods if we need to do this, instead of baking in assumptions about `arg_constraints`, though let us hold off on that until we finalize the approach. The good thing about this change is that it is self-contained and can easily be reverted later.
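The flatten/unflatten idea discussed in this thread, where array-valued parameters are returned as children and static information such as `batch_shape` travels in the second output so that `tree_unflatten` can recover it from `aux_data`, can be sketched in plain Python. `ToyExpanded` and its fields are hypothetical names used only for illustration; this is not a NumPyro class. The `(children, aux_data)` return value and the `(aux_data, children)` argument order are assumed to mirror the protocol expected by `jax.tree_util.register_pytree_node`.

```python
class ToyExpanded:
    """Hypothetical stand-in for a distribution wrapper (not a NumPyro class)."""

    def __init__(self, loc, scale, batch_shape=()):
        self.loc = loc
        self.scale = scale
        self.batch_shape = tuple(batch_shape)  # static metadata, never scanned

    def tree_flatten(self):
        # children: values that a scan may stack or slice along the time axis
        # aux_data: static metadata that is passed back unchanged
        return (self.loc, self.scale), {"batch_shape": self.batch_shape}

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        loc, scale = children
        return cls(loc, scale, **aux_data)


d = ToyExpanded(loc=0.0, scale=2.0, batch_shape=(3,))
children, aux = d.tree_flatten()

# Pretend a scan replaced the first child with the value for one time step.
restored = ToyExpanded.tree_unflatten(aux, (1.5, children[1]))
print(restored.loc, restored.scale, restored.batch_shape)
```

The static `batch_shape` survives the round trip even though the children were swapped, which is the property the comment above relies on.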
numpyro/contrib/control_flow/scan.py
Outdated
def scan(name, fn, init_value, xs, rng_key=None):
    # if there are no active Messengers, we just run and return it as expected:
    if not _PYRO_STACK:
        (rng_key, carry), (site_values, site_dists, ys) = scan_wrapper(fn, init_value, xs, rng_key=rng_key)
If we had a model where one of the sample sites was inside a scan body, and we ran inference and collected some samples, can we later use `handlers.condition` to condition the site on the observed samples?
We'll need to modify `condition` a bit, as in `substitute`, to push `param_map` to the `control_flow` primitive (probably, to distinguish `substitute` and `condition`, we name it `condition_param_map` vs `substitute_param_map`). Then I think it should work. Later, we need to record more information from the trace, such as `mask`, `scale`, `is_observed`... but for now, I think we can check if they are None or False and raise an error that asks for a feature request.
I see, so currently this supports `trace` and `seed`; how about other effect handlers? It would be nice if whatever we have could transparently support additional effect handlers without causing any surprises. I will have to play with this, though, to clearly understand what issues are involved.
I think the point is how much information we want to collect from the `trace` (which we can return instead of returning distributions for the output of `scan_wrapper` as in the current implementation), which records more information. So far, the `sample` primitive needs three more pieces of information: `mask`, `scale`, `cond_indep_stack`. The first two are easily recorded. The last one is non-trivial, but I guess we can resolve it by adding a `cond_indep_stack` key to the `control_flow` msg and, based on `cond_indep_stack`, wrapping the scan body_fn in the corresponding plate statements (before getting the body_fn trace). Anyway, I guess it is not important (our predictive utilities use `vmap` instead of `plate` to draw multiple samples).
I think I have a better understanding of the issues involved, thanks to @fehiepsi's PR. @eb8680 - Regarding our earlier discussion, pushing the handlers into the loop body will probably work for many handlers, but there are a few immediate issues:
def scan_body(c, y):
    x, i = c
    x = pyro.sample('x_{}'.format(i), ...)
    return (x, i+1), x

There are probably other issues, but these were some of the immediate ones I could see. Note that we can already record the values and return the trace by just using (2), but it will be nice if we could preserve composability as well.
@neerajprad Can we use a local
Interesting, I'm not sure how that will work. Will need to think a bit.
That's right. I was more referring to the PR that you pointed out where we can tack on auxiliary information so that it works with
A limited version in contrib seems fine. It will still be nice if we can use
I think Predictive should work as expected because it uses lax.map or vmap for predictions, so we don't have to worry about batching num_samples as in Pyro. The current implementation can handle substitute_fn and param_map in e.g. computing potential_fn or substitute init strategies out of the box. Do I misunderstand your point?
…th unconstrained messenger, add examples to docs
Thanks, @jawolf314! I just added an init file and updated the implementation to make it compatible with recent refactors. If you are using
@fehiepsi I have run the SGT examples from the NumPyro documentation (see items 1 and 2 below). The SGT examples do not use the sample statement inside of the function called by scan. I also wrote a structural time series model (LocalLinear trend like add.local.linear.trend with a seasonal component -- code at bottom of this comment). The structural time series model includes sample statements inside of the transition function called by scan -- in contrast to the SGT model. I then tried producing forecasts using the two different approaches from the NumPyro documentation
I have come across behavior with the Predictive class that I do not understand. When I use the Predictive class and the new scan function on the SGT model as demonstrated here time_series_forecasting.ipynb, the results are similar to the results currently in the NumPyro Tutorials section on time series forecasting -- where the separate function is used to describe the model for the inference stage (

Structural Time Series example

I have been using the airline-passengers.csv dataset for testing (https://github.com/jbrownlee/Datasets/blob/master/airline-passengers.csv). In addition to installing NumPyro from 4c9786d, I also have jupyter, matplotlib and pandas installed in the environment (using Anaconda). The code below is taken from a Jupyter notebook where the dashed lines represent new cells. It is modified from the NumPyro SGT examples.

import os
from IPython.display import set_matplotlib_formats
import matplotlib.pyplot as plt
import pandas as pd
import jax.numpy as jnp
from jax import lax, random, vmap
from jax.nn import softmax
import numpyro; numpyro.set_host_device_count(4)
import numpyro.distributions as dist
from numpyro.distributions import transforms
from numpyro.diagnostics import autocorrelation, hpdi
from numpyro import handlers
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.contrib.control_flow.scan import scan
#assert numpyro.__version__.startswith('0.2.4')
# ------------------------------------------------------------------------
# Read Data
dataset = pd.read_csv("airline-passengers.csv", index_col=0)
y_data = dataset.Passengers.values
data = y_data*jnp.ones(len(y_data))
data = jnp.log(data)
split = int(len(data)*0.9)
y_train, y_test = jnp.array(data[:split], dtype=jnp.float32), data[split:]
# ------------------------------------------------------------------------
# Params
num_seasons = 12
# ------------------------------------------------------------------------
# Model
def locallinear(y=None, num_seasons=12, future=0):
    N = y.shape[0]
    duration = N + future
    # empirical
    sd_y = jnp.std(y)
    y_init = y[0]
    dy_init = y[1] - y[0]
    max_y = jnp.abs(jnp.max(y))
    cauchy_sd = max_y * 0.01
    # priors
    season = numpyro.sample('season_init', dist.Cauchy(0, jnp.abs(y[:num_seasons]) * 0.3))
    season = jnp.concatenate([season[1:], season[:1]], axis=0)
    level = numpyro.sample('level_init', dist.Normal(y_init, sd_y))
    slope = numpyro.sample('slope_init', dist.Normal(dy_init, sd_y))
    level_scale = numpyro.sample('level_scale', dist.HalfCauchy(cauchy_sd))
    slope_scale = numpyro.sample('slope_scale', dist.HalfCauchy(cauchy_sd))
    y_scale = numpyro.sample('y_scale', dist.HalfCauchy(cauchy_sd))
    # state transition
    def transition(state, t):
        level, slope, season = state
        season_t = -1 * jnp.sum(season[:-1])
        season = jnp.concatenate([season[1:], season[:1]], axis=0)
        level = numpyro.sample('level', dist.Normal(level + slope, level_scale))
        slope = numpyro.sample('slope', dist.Normal(slope, slope_scale))
        # combined state
        z_t = level + season_t
        return (level, slope, season), z_t
    (level, slope, season), z_t = scan(transition, (level, slope, season), jnp.arange(1, duration))
    # observation equation
    if future == 0:
        y_pred = numpyro.sample('y', dist.Normal(z_t, y_scale), obs=y[1:])
    else:
        z_t = z_t[y.shape[0] - 1:]
        assert z_t.shape[0] == future
        y_pred = numpyro.sample('y', dist.Normal(z_t, y_scale))
    return level, slope, season
# ------------------------------------------------------------------------
# Inference
kernel = NUTS(locallinear)
mcmc = MCMC(kernel, num_warmup=5000, num_samples=5000, num_chains=4)
mcmc.run(random.PRNGKey(1), y_train, num_seasons=num_seasons)
mcmc.print_summary()
samples = mcmc.get_samples()
# ------------------------------------------------------------------------
# Forecast using approach here http://pyro.ai/numpyro/time_series_forecasting.html
def locallinear_forecast(future, sample, y, level, slope, season):
    y_scale = sample["y_scale"]
    level_scale = sample['level_scale']
    slope_scale = sample['slope_scale']
    #yfs = [0] * future
    for t in range(future):
        # state transition
        season_t = -1 * jnp.sum(season[:-1])
        season = jnp.concatenate([season[1:], season[:1]], axis=0)
        level = numpyro.sample('level{}'.format(t), dist.Normal(level + slope, level_scale))
        slope = numpyro.sample('slope{}'.format(t), dist.Normal(slope, slope_scale))
        z_t = level + season_t
        # observation
        yf = numpyro.sample("yf[{}]".format(t), dist.Normal(z_t, y_scale))
        #yfs[t] = yf
def forecast(future, rng_key, sample, y, num_seasons):
    level, slope, season = handlers.substitute(locallinear, sample)(y, num_seasons)
    forecast_model = handlers.seed(locallinear_forecast, rng_key)
    forecast_trace = handlers.trace(forecast_model).get_trace(future, sample, y, level, slope, season)
    results = [forecast_trace["yf[{}]".format(t)]["value"]
               for t in range(future)]
    return jnp.stack(results, axis=0)
rng_keys = random.split(random.PRNGKey(1), samples["y_scale"].shape[0])
forecast_marginal = vmap(lambda rng_key, sample: forecast(
    len(y_test), rng_key, sample, y_train, num_seasons=num_seasons))(rng_keys, samples)
# ------------------------------------------------------------------------
# metrics
y_pred = jnp.median(forecast_marginal, axis=0)
print(y_pred)
sMAPE = jnp.mean(jnp.abs(jnp.exp(y_pred) - jnp.exp(y_test)) / jnp.abs(jnp.exp(y_pred) + jnp.exp(y_test))) * 200
msqrt = jnp.sqrt(jnp.mean((jnp.exp(y_pred) - jnp.exp(y_test)) ** 2))
print("sMAPE: {:.2f}%, rmse: {:.2f}".format(sMAPE, msqrt))
# ------------------------------------------------------------------------
# figures
plt.figure(figsize=(12, 8))
plt.plot(jnp.exp(data))
t_future = jnp.arange(len(y_data))[split:]
t_future_min = min(t_future)
hpd_low, hpd_high = hpdi(forecast_marginal)
plt.plot(t_future, jnp.exp(y_pred), lw=2)
plt.axvline(x=t_future_min, color='k')
plt.fill_between(t_future, jnp.exp(hpd_low), jnp.exp(hpd_high), alpha=0.3)
p50 = jnp.percentile(forecast_marginal, q=50, axis=0)
plt.plot(t_future, jnp.exp(p50))
# ------------------------------------------------------------------------
# Forecast using approach here https://github.com/pyro-ppl/numpyro/blob/87e94c0c988e0092ff0b29f4651ffcd022d68bec/notebooks/source/time_series_forecasting.ipynb
predictive = Predictive(locallinear, samples, return_sites=["y"])
forecast_marginal = predictive(random.PRNGKey(1), y_train, num_seasons=num_seasons, future=len(y_test))["y"]
# ------------------------------------------------------------------------
# metrics
y_pred = jnp.median(forecast_marginal, axis=0)
print(y_pred)
sMAPE = jnp.mean(jnp.abs(jnp.exp(y_pred) - jnp.exp(y_test)) / jnp.abs(jnp.exp(y_pred) + jnp.exp(y_test))) * 200
msqrt = jnp.sqrt(jnp.mean((jnp.exp(y_pred) - jnp.exp(y_test)) ** 2))
print("sMAPE: {:.2f}%, rmse: {:.2f}".format(sMAPE, msqrt))
# ------------------------------------------------------------------------
# figures
plt.figure(figsize=(12, 8))
plt.plot(jnp.exp(data))
t_future = jnp.arange(len(y_data))[split:]
t_future_min = min(t_future)
hpd_low, hpd_high = hpdi(forecast_marginal)
plt.plot(t_future, jnp.exp(y_pred), lw=2)
plt.axvline(x=t_future_min, color='k')
plt.fill_between(t_future, jnp.exp(hpd_low), jnp.exp(hpd_high), alpha=0.3)
p50 = jnp.percentile(forecast_marginal, q=50, axis=0)
plt.plot(t_future, jnp.exp(p50))

One further reason I think it may be related to the new scan function is that when I write the same model without sample statements inside of the transition function called by scan, both methods of producing the forecasts give nearly identical results.

import os
from IPython.display import set_matplotlib_formats
import matplotlib.pyplot as plt
import pandas as pd
import jax.numpy as jnp
from jax import lax, random, vmap
from jax.nn import softmax
import numpyro; numpyro.set_host_device_count(4)
import numpyro.distributions as dist
from numpyro.distributions import transforms
from numpyro.diagnostics import autocorrelation, hpdi
from numpyro import handlers
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.contrib.control_flow.scan import scan
#assert numpyro.__version__.startswith('0.2.4')
# ------------------------------------------------------------------------
# Read Data
dataset = pd.read_csv("airline-passengers.csv", index_col=0)
y_data = dataset.Passengers.values
data = y_data*jnp.ones(len(y_data))
data = jnp.log(data)
split = int(len(data)*0.9)
y_train, y_test = jnp.array(data[:split], dtype=jnp.float32), data[split:]
# ------------------------------------------------------------------------
# Params
num_seasons = 12
# ------------------------------------------------------------------------
# Model
def locallinear(y=None, num_seasons=12, future=0):
    N = y.shape[0]
    duration = N + future
    # empirical
    sd_y = jnp.std(y)
    y_init = y[0]
    dy_init = y[1] - y[0]
    max_y = jnp.abs(jnp.max(y))
    cauchy_sd = max_y * 0.01
    # priors
    season = numpyro.sample('season_init', dist.Cauchy(0, jnp.abs(y[:num_seasons]) * 0.3))
    season = jnp.concatenate([season[1:], season[:1]], axis=0)
    level = numpyro.sample('level_init', dist.Normal(y_init, sd_y))
    slope = numpyro.sample('slope_init', dist.Normal(dy_init, sd_y))
    #level_scale = numpyro.sample('level_scale', dist.HalfCauchy(cauchy_sd))
    #slope_scale = numpyro.sample('slope_scale', dist.HalfCauchy(cauchy_sd))
    y_scale = numpyro.sample('y_scale', dist.HalfCauchy(cauchy_sd))
    # state transition
    def transition(state, t):
        level, slope, season = state
        season_t = -1 * jnp.sum(season[:-1])
        season = jnp.concatenate([season[1:], season[:1]], axis=0)
        #level = numpyro.sample('level', dist.Normal(level + slope, level_scale))
        #slope = numpyro.sample('slope', dist.Normal(slope, slope_scale))
        level = level + slope
        # combined state
        z_t = level + season_t
        return (level, slope, season), z_t
    (level, slope, season), z_t = scan(transition, (level, slope, season), jnp.arange(1, duration))
    # observation equation
    if future == 0:
        y_pred = numpyro.sample('y', dist.Normal(z_t, y_scale), obs=y[1:])
    else:
        z_t = z_t[y.shape[0] - 1:]
        assert z_t.shape[0] == future
        y_pred = numpyro.sample('y', dist.Normal(z_t, y_scale))
    return level, slope, season
# ------------------------------------------------------------------------
# Inference
kernel = NUTS(locallinear)
mcmc = MCMC(kernel, num_warmup=5000, num_samples=5000, num_chains=4)
mcmc.run(random.PRNGKey(1), y_train, num_seasons=num_seasons)
mcmc.print_summary()
samples = mcmc.get_samples()
# ------------------------------------------------------------------------
# Forecast using approach here http://pyro.ai/numpyro/time_series_forecasting.html
def locallinear_forecast(future, sample, y, level, slope, season):
    y_scale = sample["y_scale"]
    #level_scale = sample['level_scale']
    #slope_scale = sample['slope_scale']
    #yfs = [0] * future
    for t in range(future):
        # state transition
        season_t = -1 * jnp.sum(season[:-1])
        season = jnp.concatenate([season[1:], season[:1]], axis=0)
        #level = numpyro.sample('level{}'.format(t), dist.Normal(level + slope, level_scale))
        level = level + slope
        #slope = numpyro.sample('slope{}'.format(t), dist.Normal(slope, slope_scale))
        z_t = level + season_t
        # observation
        yf = numpyro.sample("yf[{}]".format(t), dist.Normal(z_t, y_scale))
        #yfs[t] = yf
def forecast(future, rng_key, sample, y, num_seasons):
    level, slope, season = handlers.substitute(locallinear, sample)(y, num_seasons)
    forecast_model = handlers.seed(locallinear_forecast, rng_key)
    forecast_trace = handlers.trace(forecast_model).get_trace(future, sample, y, level, slope, season)
    results = [forecast_trace["yf[{}]".format(t)]["value"]
               for t in range(future)]
    return jnp.stack(results, axis=0)
rng_keys = random.split(random.PRNGKey(1), samples["y_scale"].shape[0])
forecast_marginal = vmap(lambda rng_key, sample: forecast(
    len(y_test), rng_key, sample, y_train, num_seasons=num_seasons))(rng_keys, samples)
# ------------------------------------------------------------------------
# metrics
y_pred = jnp.median(forecast_marginal, axis=0)
print(y_pred)
sMAPE = jnp.mean(jnp.abs(jnp.exp(y_pred) - jnp.exp(y_test)) / jnp.abs(jnp.exp(y_pred) + jnp.exp(y_test))) * 200
msqrt = jnp.sqrt(jnp.mean((jnp.exp(y_pred) - jnp.exp(y_test)) ** 2))
print("sMAPE: {:.2f}%, rmse: {:.2f}".format(sMAPE, msqrt))
# ------------------------------------------------------------------------
# figures
plt.figure(figsize=(12, 8))
plt.plot(jnp.exp(data))
t_future = jnp.arange(len(y_data))[split:]
t_future_min = min(t_future)
hpd_low, hpd_high = hpdi(forecast_marginal)
plt.plot(t_future, jnp.exp(y_pred), lw=2)
plt.axvline(x=t_future_min, color='k')
plt.fill_between(t_future, jnp.exp(hpd_low), jnp.exp(hpd_high), alpha=0.3)
p50 = jnp.percentile(forecast_marginal, q=50, axis=0)
plt.plot(t_future, jnp.exp(p50))
# ------------------------------------------------------------------------
# Forecast using approach here https://github.com/pyro-ppl/numpyro/blob/87e94c0c988e0092ff0b29f4651ffcd022d68bec/notebooks/source/time_series_forecasting.ipynb
predictive = Predictive(locallinear, samples, return_sites=["y"])
forecast_marginal = predictive(random.PRNGKey(1), y_train, num_seasons=num_seasons, future=len(y_test))["y"]
# ------------------------------------------------------------------------
# metrics
y_pred = jnp.median(forecast_marginal, axis=0)
print(y_pred)
sMAPE = jnp.mean(jnp.abs(jnp.exp(y_pred) - jnp.exp(y_test)) / jnp.abs(jnp.exp(y_pred) + jnp.exp(y_test))) * 200
msqrt = jnp.sqrt(jnp.mean((jnp.exp(y_pred) - jnp.exp(y_test)) ** 2))
print("sMAPE: {:.2f}%, rmse: {:.2f}".format(sMAPE, msqrt))
# ------------------------------------------------------------------------
# figures
plt.figure(figsize=(12, 8))
plt.plot(jnp.exp(data))
t_future = jnp.arange(len(y_data))[split:]
t_future_min = min(t_future)
hpd_low, hpd_high = hpdi(forecast_marginal)
plt.plot(t_future, jnp.exp(y_pred), lw=2)
plt.axvline(x=t_future_min, color='k')
plt.fill_between(t_future, jnp.exp(hpd_low), jnp.exp(hpd_high), alpha=0.3)
p50 = jnp.percentile(forecast_marginal, q=50, axis=0)
plt.plot(t_future, jnp.exp(p50))
Whoa, we really love examples that show something is wrong! I'll look deeper into your example. Thanks, @jawolf314!
@jawolf314 I think I know the reason. In the current implementation, the condition value (the value from the posterior, which has the same size as the training data) needs to have the same size as the scan length, but we didn't raise an error when that is not the case. With the current implementation and with the help of

last_val, ys_train = scan(fn, init_val, xs_train)
_, ys_pred = scan(namespace(fn, 'pred'), last_val, xs_test)

then you can collect

But after getting up this morning, I think we can support the case: scan length > length of the condition value using

Btw, if things work well, could you help us by contributing your example (like some in the examples folder) or, even better, writing a small tutorial with the above model? Currently, we don't have the bandwidth to both develop new features and explore useful applications to illustrate the features that we added. So your contributions are very welcome and I believe they will be very helpful for other users. Thanks!!
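The two-scan pattern in this comment (scan over the training inputs, then start a second scan from the last carry) can be illustrated with a pure-Python model of scan semantics. This is only a sketch: `scan_reference` and the toy `transition` are hypothetical names, the transition is deterministic, and the `namespace(fn, 'pred')` renaming of sample sites is not modelled here.

```python
def scan_reference(f, init, xs):
    """Pure-Python model of scan: thread a carry through xs and collect outputs."""
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def transition(level, t):
    # Toy deterministic transition: the carry grows by t at each step.
    new_level = level + t
    return new_level, new_level


xs_train, xs_test = [1, 2, 3], [4, 5]

# One scan over the full horizon ...
_, ys_full = scan_reference(transition, 0, xs_train + xs_test)

# ... matches a training scan followed by a forecasting scan that starts
# from the last carry of the training scan.
last_val, ys_train = scan_reference(transition, 0, xs_train)
_, ys_pred = scan_reference(transition, last_val, xs_test)

print(ys_train, ys_pred)
```

Because the carry is the only state passed between steps, splitting the horizon at the train/test boundary does not change the outputs; it only gives the forecast steps their own scan call.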
@jawolf314 I have made an enhancement for your use case. Please let me know if things work as expected for you. :) (I might make a couple of enhancements to clean up some rough spots in the code after discussing with other devs - but I'll make sure that things will work as expected for you. If you observe something still wrong, please let us know.)
Thanks @fehiepsi!
Both methods give similar results after updating the model as you suggested. I just saw you pushed d7583fc. Will test the original version of the locallinear model on the updated code and comment shortly.

def locallinear(y=None, num_seasons=12, future=0):
    N = y.shape[0]
    duration = N + future
    #print(N, duration)
    # empirical
    sd_y = jnp.std(y)
    y_init = y[0]
    dy_init = y[1] - y[0]
    max_y = jnp.abs(jnp.max(y))
    cauchy_sd = max_y * 0.01
    # priors
    season = numpyro.sample('season_init', dist.Cauchy(0, jnp.abs(y[:num_seasons]) * 0.3))
    season = jnp.concatenate([season[1:], season[:1]], axis=0)
    level = numpyro.sample('level_init', dist.Normal(y_init, sd_y))
    slope = numpyro.sample('slope_init', dist.Normal(dy_init, sd_y))
    level_scale = numpyro.sample('level_scale', dist.HalfCauchy(cauchy_sd))
    slope_scale = numpyro.sample('slope_scale', dist.HalfCauchy(cauchy_sd))
    y_scale = numpyro.sample('y_scale', dist.HalfCauchy(cauchy_sd))
    # state transition
    def transition(state, t):
        level, slope, season = state
        season_t = -1 * jnp.sum(season[:-1])
        season = jnp.concatenate([season[1:], season[:1]], axis=0)
        level = numpyro.sample('level', dist.Normal(level + slope, level_scale))
        slope = numpyro.sample('slope', dist.Normal(slope, slope_scale))
        # combined state
        z_t = level + season_t
        return (level, slope, season), z_t
    # added https://github.com/pyro-ppl/numpyro/pull/595#issuecomment-646088220
    def transition_pred(state, t):
        level, slope, season = state
        season_t = -1 * jnp.sum(season[:-1])
        season = jnp.concatenate([season[1:], season[:1]], axis=0)
        level = numpyro.sample('level_pred', dist.Normal(level + slope, level_scale))
        slope = numpyro.sample('slope_pred', dist.Normal(slope, slope_scale))
        # combined state
        z_t = level + season_t
        return (level, slope, season), z_t
    # updated following https://github.com/pyro-ppl/numpyro/pull/595#issuecomment-646088220
    (level, slope, season), z_t = scan(transition, (level, slope, season), jnp.arange(1, N))
    # added following https://github.com/pyro-ppl/numpyro/pull/595#issuecomment-646088220
    if future != 0:
        _, z_t_pred = scan(transition_pred, (level, slope, season), jnp.arange(N, duration))
    # observation equation
    if future == 0:
        y_pred = numpyro.sample('y', dist.Normal(z_t, y_scale), obs=y[1:])
    else:
        # updated following https://github.com/pyro-ppl/numpyro/pull/595#issuecomment-646088220
        assert z_t_pred.shape[0] == future
        y_pred = numpyro.sample('y', dist.Normal(z_t_pred, y_scale))
    return level, slope, season
Yes, thank you. I am happy to contribute an example.
Thanks for the feedback, @jawolf314! Glad that it works. 👯♂️
This PR is ready for review. I am pretty happy with the current implementation (all the rough spots that I flagged in the comments of previous commits are removed now). :D
I will take a closer look at the distribution flattening/ unflattening logic. This is looking really good, and is a very important feature to have. Look forward to merging this soon.
@@ -214,7 +214,7 @@ def __init__(self, df, validate_args=None):

 @copy_docs_from(Distribution)
 class GaussianRandomWalk(Distribution):
-    arg_constraints = {'num_steps': constraints.positive_integer, 'scale': constraints.positive}
+    arg_constraints = {'scale': constraints.positive, 'num_steps': constraints.positive_integer}
Is there a reason for this change?
EDIT: I see why we have this. Note my comment below.
"""
This primitive scans a function over the leading array axes of
`xs` while carrying along state. See :func:`jax.lax.scan` for more
information.
@fehiepsi - What currently are the limitations of scan? e.g. distributions, compatibility with other handlers, etc. Maybe we can add a note to that effect.
I think the current limitations are:
- It only records `sample` and `deterministic` sites in `f`, not `param`. The main reason is that we need logic to flatten/unflatten the `constraint` field, and I was too lazy to do so.
- I think most handlers (e.g. `scale`, `mask`) will work because we do `apply_stack` for scanned messages. The one that won't work is `plate`, where users will need to move `plate` statements inside `f`. I'll add a note about it. Do you suspect something else that won't work out of the box?
- For distributions, I have raised errors for unsupported distributions and told users to report the issue to us.
I think it is totally fine to not support `param` to begin with. I think most effect handlers should work out of the box, but let me take a look in detail.
We can add a disclaimer.
.. warning::
    This is an experimental utility function that allows users to use JAX control flow with NumPyro's effect handlers. Currently, `sample` and `deterministic` sites within the scan body are supported. If you notice that any effect handlers or distributions are unsupported, please file an issue.
Thanks Neeraj, I'll add it.
cls.tree_flatten,
cls.tree_unflatten)

def tree_flatten(self):
I think it is probably better not to expose a default implementation. The reason is that this relies on an implementation detail of CPython (Python 3.6+), namely that dictionaries preserve the insertion order of their keys. Can we raise `NotImplementedError` for both of these? If the default works for a lot of distributions, we could use an OrderedDict for `arg_constraints`. This is, however, an implicit ordering constraint that users will need to know about if they are writing their own distribution classes, so it would be best to just raise an error.
This default implementation works because most distributions only need ndarray args in their constructor (with some special cases such as `event_dim` in Delta, `covariance_matrix=None` in MVN, or `dimension=4` in LKJ, ...). I will try to use the `inspect` module here to guarantee that the order of args matches.
Actually, another possibility is to sort by key here:
return tuple(getattr(self, param) for param in sorted(self.arg_constraints.keys())), None
and then, for unflatten:
return cls(**dict(zip(sorted(self.arg_constraints.keys()), params)))
Would that work?
Oh, nice trick at the unflatten! Yes, that should work. Many thanks!
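The sorted-key suggestion can be checked on a toy class. The sketch below uses a hypothetical `ToyNormal` (not NumPyro's `Normal`) whose `arg_constraints` is deliberately declared in a different order than the constructor arguments; the constraint values are placeholder strings rather than real constraint objects.

```python
class ToyNormal:
    """Hypothetical distribution-like class (not NumPyro's Normal)."""

    # Deliberately declared in a different order than the constructor.
    arg_constraints = {"scale": "positive", "loc": "real"}

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def tree_flatten(self):
        # Sorting makes the parameter order independent of dict ordering.
        keys = sorted(self.arg_constraints.keys())
        return tuple(getattr(self, k) for k in keys), None

    @classmethod
    def tree_unflatten(cls, aux_data, params):
        keys = sorted(cls.arg_constraints.keys())
        # Keyword construction makes the positional order irrelevant.
        return cls(**dict(zip(keys, params)))


params, aux = ToyNormal(loc=1.0, scale=3.0).tree_flatten()
rebuilt = ToyNormal.tree_unflatten(aux, params)
print(params, rebuilt.loc, rebuilt.scale)
```

Since both directions sort the same key set and unflattening passes keyword arguments, the round trip does not depend on how `arg_constraints` was written, as long as every key is also a constructor argument name.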
    # this branch happens when substitute_fn is init_strategy,
    # where we apply init_strategy to each element in the scanned series
    return value
elif value_ndim == fn_ndim + 1:
So the time dim is always taken to be the first dim, even in the presence of outermost plate dims? I think that should be fine, but could you add a note in the docstring/comment, if that's the case?
You are right, the time dim will be the first dim. I think that is covered in the note in the docs (`plate` cannot be used as a context for `scan`). We also raise an error if users do
with plate(...):
    scan(...)
numpyro/contrib/control_flow/scan.py
Outdated
else:
    raise RuntimeError(f"Substituted value for site {site['name']} "
                       "requires length greater than or equal to scan length."
                       f" Expected length >= {length}, but got {shape[0]}.")
Shouldn't it be less than the scan length (length of the time dim)? This seems to suggest otherwise.
Oops... thanks for pointing it out!
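The direction of the check raised in this review comment can be sketched as follows. `check_substituted_length` is a hypothetical helper written only for illustration and is not the code in `scan.py`; it assumes, as the discussion suggests, that a substituted value may be shorter than the scan (the remaining steps are then sampled) but never longer.

```python
def check_substituted_length(site_name, value_length, scan_length):
    """Hypothetical helper: reject substituted values longer than the scan."""
    if value_length > scan_length:
        raise RuntimeError(
            f"Substituted value for site {site_name} "
            "requires length less than or equal to scan length. "
            f"Expected length <= {scan_length}, but got {value_length}."
        )
    # Number of trailing steps that are not covered by the substituted value.
    return scan_length - value_length


remaining = check_substituted_length("level", value_length=130, scan_length=144)
print(remaining)
```

With 130 substituted steps and a scan of length 144, the helper reports 14 uncovered steps; a value of length 150 would raise the error instead.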
I just had some minor clarifying comments. Happy to merge otherwise! 🎉
Fixes #566.
This is one solution for #566. As mentioned in that thread, we can rewrite the body_fn of `scan` so that it is compatible with our inference utilities. Basically, I borrow the idea of pyro.reparam (consider latent sites in scan body_fn as deterministic sites, transformed from something - here, it is simply an unpacking transform, which unpacks values collected from scan) and ~~add a convenient class ScanDistribution to handle `sample` and `log_prob` separately~~ create a primitive `control_flow` to let the `substitute` handler push its param_map to scan.
TODO