Add scan primitive #595

Merged: 36 commits into pyro-ppl:master on Jun 24, 2020
Conversation

@fehiepsi (Member) commented May 15, 2020

Fixes #566.

This is one solution for #566. As mentioned in that thread, we can rewrite the body_fn of scan so that it is compatible with our inference utilities. Basically, I borrow the idea of pyro.reparam (treat latent sites in the scan body_fn as deterministic sites transformed from something else; here it is simply an unpacking transform that unpacks values collected from scan), add a convenient class ScanDistribution to handle sample and log_prob separately, and create a primitive control_flow to let the substitute handler push its param_map to scan.

TODO

  • Handle constrained support. (Dynamic support will not be addressed in this PR because we will eventually remove such stuff.) We need to register distribution classes as JAX pytrees so that we can return the transform classes along with site values.
  • Test on the state-space model in the forum.
  • Add tests for Predictive.
  • Add tests for flattening/unflattening all distributions.

@fehiepsi (Member, Author) commented May 18, 2020

@neerajprad @eb8680 I have made a scan version that works as expected. Please let me know what you think. :) Here is the NUTS summary for @dsheldon's model:

    # assumed imports for this snippet (not shown in the original comment);
    # note that this early revision of scan took a site name as its first argument
    import jax.numpy as np
    import numpyro
    import numpyro.distributions as dist
    from numpyro.contrib.control_flow import scan

    def target(T=10, q=1, r=1, phi=0., beta=0.):

        def transition(state, i):
            x0, mu0 = state
            x1 = numpyro.sample('x', dist.Normal(phi * x0, q))
            mu1 = beta * mu0 + x1
            y1 = numpyro.sample('y', dist.Normal(mu1, r))
            return (x1, mu1), (x1, y1)

        mu0 = x0 = numpyro.sample('x_0', dist.Normal(0, q))
        y0 = numpyro.sample('y_0', dist.Normal(mu0, r))

        _, xy = scan("scan", transition, (x0, mu0), np.arange(1, T))
        x, y = xy

        return np.append(x0, x), np.append(y0, y)

returns

                mean       std    median      5.0%     95.0%     n_eff     r_hat
      x[0]      0.00      0.82     -0.02     -1.14      1.48    103.49      0.99
      x[1]     -0.16      1.01     -0.25     -1.71      1.25     92.78      0.99
      x[2]      0.04      1.04      0.14     -1.73      1.44    130.34      0.99
      x[3]     -0.02      0.99      0.04     -1.70      1.40    128.91      1.00
      x[4]      0.04      1.04      0.17     -1.45      1.91    263.83      0.99
      x[5]     -0.02      0.83      0.02     -1.24      1.30    116.25      0.99
      x[6]     -0.04      0.92     -0.03     -1.57      1.38     81.22      1.00
      x[7]     -0.09      1.19     -0.08     -2.03      1.94    125.77      0.99
      x[8]     -0.08      0.83     -0.19     -1.15      1.60    231.86      0.99
       x_0      0.06      1.00     -0.03     -1.30      1.79    148.38      1.00
      y[0]     -0.07      1.24      0.08     -1.87      1.95     64.97      1.00
      y[1]     -0.07      1.51      0.09     -2.61      2.33     82.72      0.99
      y[2]      0.04      1.49      0.25     -2.35      2.21     98.36      0.99
      y[3]     -0.12      1.49     -0.01     -2.82      2.04    121.21      1.01
      y[4]      0.12      1.55     -0.02     -2.80      2.13    226.02      0.99
      y[5]      0.01      1.18     -0.10     -1.99      1.92     77.37      0.99
      y[6]     -0.17      1.31     -0.20     -2.29      1.91     99.73      1.00
      y[7]     -0.05      1.67     -0.02     -2.30      2.66    131.85      0.99
      y[8]     -0.09      1.25     -0.06     -1.49      2.59    197.05      1.00
       y_0     -0.01      1.45     -0.01     -2.30      2.32    136.60      0.99

Number of divergences: 0

as expected.

(I marked this as "blocked" because I want to address the reparam issue first, so I can get rid of the base_param_map/param_map logic.)

    # XXX: this won't work with MultivariateNormal, which has different interpretations
    # depending on which one is specified among `scale_tril`, `covariance_matrix`, `precision_matrix`
    # TODO: arrange keys in arg_constraints to be in the same order as the constructor
    return tuple(getattr(self, param) for param in self.arg_constraints.keys()), None
Member:
Would this work with all distribution classes, such as MaskedDistribution and ExpandedDistribution?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! I think it should work, but we need to override the methods in such subclasses. Auxiliary information such as batch_shape and reinterpreted_batch_ndims can be put into the second output (instead of None here), and it will be recovered correctly through the aux_data argument in tree_unflatten. mask can be scanned like the other arg_constraints. TransformedDistribution is trickier; I think we need to register pytrees for those too. We can add a try/except here (e.g. to check that self.arg_constraints is non-empty, and otherwise raise an error asking users to submit a feature request). :)
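As a rough sketch of that idea (names and structure are my own illustration, not this PR's code), a wrapper distribution registered as a JAX pytree could route static attributes through aux_data like this:

from jax import tree_util

class ExpandedSketch:
    """Toy stand-in for a wrapper like ExpandedDistribution (illustrative only)."""
    def __init__(self, base_dist, batch_shape=()):
        self.base_dist = base_dist
        self.batch_shape = tuple(batch_shape)

    def tree_flatten(self):
        # array-valued children get scanned/vmapped; static info such as
        # batch_shape travels through aux_data (the second output) instead of None
        return (self.base_dist,), (self.batch_shape,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (batch_shape,) = aux_data
        (base_dist,) = children
        return cls(base_dist, batch_shape)

tree_util.register_pytree_node(
    ExpandedSketch,
    lambda d: d.tree_flatten(),
    ExpandedSketch.tree_unflatten,
)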

Member:
It might be better to provide Distribution.flatten or Distribution.unflatten methods if we need to do this, instead of baking assumptions into arg_constraints, though let us hold off on that until we finalize the approach. The good thing about this change is that it is self-contained and can be easily reverted later.

def scan(name, fn, init_value, xs, rng_key=None):
    # if there are no active Messengers, we just run and return it as expected:
    if not _PYRO_STACK:
        (rng_key, carry), (site_values, site_dists, ys) = scan_wrapper(fn, init_value, xs, rng_key=rng_key)
Member:
If we had a model where one of the sample sites was inside a scan body, and we ran inference and collected some samples, can we later use handlers.condition to condition the site to observed samples?

Member Author (@fehiepsi):
We'll need to modify condition a bit, as with substitute, to push its param_map to the control_flow primitive (to distinguish substitute from condition, we could name them condition_param_map vs substitute_param_map). Then I think it should work. Later, we will need to record more information from the trace, such as mask, scale, is_observed..., but for now I think we can check whether they are None or False and raise an error asking for feature requests.

Member:
I see, so currently this supports trace and seed; how about other effect handlers? It will be nice if whatever we have can transparently support additional effect handlers without causing any surprises. I will have to play with this, though, to clearly understand the issues involved.

Member Author (@fehiepsi):
I think the point is how much information we want to collect from the trace (which we could return from scan_wrapper instead of returning distributions, as in the current implementation, since the trace records more information). So far, the sample primitive needs three more pieces of information: mask, scale, and cond_indep_stack. The first two are easily recorded. The last one is non-trivial, but I guess we can resolve it by adding a cond_indep_stack key to the control_flow msg; based on cond_indep_stack, we would wrap the scan body_fn in the corresponding plate statements (before getting the body_fn trace). Anyway, it is not that important, I guess (our predictive utilities use vmap instead of plate to draw multiple samples).

@neerajprad (Member) commented May 28, 2020

I think I have a better understanding of the issues involved, thanks to @fehiepsi's PR.

@eb8680 - Regarding our earlier discussion, pushing the handlers into the loop body will probably work for many handlers, but there are a few immediate issues:

  1. Naming of sites - If in the loop body we have something like pyro.sample('x', ..), this will be called multiple times, so we need the site names to be distinct. Another option is to batch the values in the trace, as in this PR (which is probably a better idea). If we instead use distinct names based on an index that is carried, it won't work either, since i will be an abstract value and the name will be something like x_Traced<ShapedArray(int32[]):
def scan_body(c, y):
  x, i = c
  x = pyro.sample('x_{}'.format(i), ...)
  return (x, i+1), x
  2. Trace handler - A bigger issue is with trace. We'll have to push .get_trace inside the scan body so that the site values / distribution arguments are captured as concrete values, not abstract arrays, which means we'll need to return the trace from the body function. I think that can be done, as @fehiepsi has shown, by having distributions be JAX pytrees. If we go further and make the trace data structure itself a pytree, we should be able to reconstruct the trace object, where site x above would be batched along the first dim. However, even if we are able to do so, composing handlers requires us to solve (1) in some regard: e.g. if we later want to condition the site x or replay based on a batched array, the condition handler will need to pull the right index from the batched array and not just replace x by the batched array. It might be possible to read off the handlers from the stack and pass these to scan, but I haven't thought through this completely.

There are probably other issues, but these were some of the immediate ones I could see. Note that we can already record the values and return the trace by just using (2), but it would be nice if we could preserve composability as well.

@fehiepsi (Member, Author):
@neerajprad Can we use a local _PYRO_STACK (which keeps all information from the global _PYRO_STACK) in the body_fn of scan? That way, things might be more composable. In addition, a trace is just a dict of pytrees, so we can return the trace as an output of scan, as far as I can see. However, we still need to manipulate dim stuff because vmap stacks results in the first dimension by default. Anyway, unless we have some real applications, let's start with a limited version first (I believe this already covers most use cases, where users only want to convert naive loops from other languages into NumPyro code, without the mask, scale, or plate effect handlers).
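To illustrate the dict-of-pytrees point with a self-contained toy (this is not the PR's code): Python dicts are already JAX pytrees, so lax.scan can return a trace-shaped dict whose site values come back stacked along the first (time) dimension.

import jax.numpy as jnp
from jax import lax, random

def body(rng_key, _):
    rng_key, subkey = random.split(rng_key)
    x = random.normal(subkey)
    trace = {"x": {"value": x}}          # minimal trace-like dict for one step
    return rng_key, trace

_, traces = lax.scan(body, random.PRNGKey(0), jnp.arange(5))
print(traces["x"]["value"].shape)        # (5,): values batched along the time dim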

@neerajprad (Member) commented May 29, 2020

> @neerajprad Can we use a local _PYRO_STACK (which keeps all information from the global _PYRO_STACK) in the body_fn of scan? That way, things might be more composable.

Interesting, I'm not sure how that will work. Will need to think a bit.

> In addition, a trace is just a dict of pytrees, so we can return the trace as an output of scan, as far as I can see. However, we still need to manipulate dim stuff because vmap stacks results in the first dimension by default.

That's right. I was more referring to the PR that you pointed out where we can tack on auxiliary information so that it works with vmap (unless I misunderstood that).

> Anyway, unless we have some real applications, let's start with a limited version first (I believe this already covers most use cases, where users only want to convert naive loops from other languages into NumPyro code, without the mask, scale, or plate effect handlers).

A limited version in contrib seems fine. It would still be nice if we could use Predictive at the very least (which will be a very common thing to do), and that will probably require us to solve at least some of these issues. This is also quite important for funsor (and discrete enumeration), IIRC.

@fehiepsi (Member, Author):
I think Predictive should work as expected because it uses lax.map or vmap for predictions, so we don't have to worry about batching num_samples as in Pyro. The current implementation can handle substitute_fn and param_map out of the box, e.g. when computing potential_fn or substituting init strategies. Am I misunderstanding your point?

@fehiepsi fehiepsi removed the WIP label May 31, 2020
@fehiepsi (Member, Author):
Thanks, @jawolf314! I just added the init file and updated the implementation to make it compatible with recent refactors. If you are using the scan primitive, please let me know if anything looks gross. :)

@jawolf314 commented Jun 17, 2020

@fehiepsi I wanted to follow up with an observation about how the Predictive class works with the new scan primitive from your dev branch (commit 4c9786d).

I have run the SGT examples from the NumPyro documentation (see items 1 and 2 below). The SGT examples do not use sample statements inside the function called by scan. I also wrote a structural time series model (a local linear trend, like add.local.linear.trend, with a seasonal component; code at the bottom of this comment). Unlike the SGT model, the structural time series model includes sample statements inside the transition function called by scan. I then tried producing forecasts using the two different approaches from the NumPyro documentation:

  1. Using the Predictive class, as in time_series_forecasting.ipynb
  2. Using separate inference and forecasting functions, as in the NumPyro Tutorials section on time series forecasting

I have come across behavior of the Predictive class that I do not understand. When I use the Predictive class and the new scan function on the SGT model, as demonstrated in time_series_forecasting.ipynb, the results are similar to those currently in the NumPyro Tutorials section on time series forecasting, where separate functions describe the model for the inference stage (sgt) and the forecast stage (sgt_forecast). However, with the structural time series model, I am finding that the two approaches produce different forecasts. From what I can tell, the forecasts using the Predictive class do not seem to work correctly. I am wondering if the new scan function is behaving correctly with the Predictive class.

Structural Time Series example

I have been using the airline-passengers.csv dataset for testing (https://github.com/jbrownlee/Datasets/blob/master/airline-passengers.csv). In addition to installing NumPyro from 4c9786d, I also have jupyter, matplotlib, and pandas installed in the environment (using Anaconda). The code below is taken from a Jupyter notebook, where the dashed lines represent new cells. It is adapted from the NumPyro SGT examples.

import os
from IPython.display import set_matplotlib_formats
import matplotlib.pyplot as plt
import pandas as pd
import jax.numpy as jnp
from jax import lax, random, vmap
from jax.nn import softmax
import numpyro; numpyro.set_host_device_count(4)
import numpyro.distributions as dist
from numpyro.distributions import transforms
from numpyro.diagnostics import autocorrelation, hpdi
from numpyro import handlers
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.contrib.control_flow.scan import scan

#assert numpyro.__version__.startswith('0.2.4')

# ------------------------------------------------------------------------
# Read Data
dataset = pd.read_csv("airline-passengers.csv", index_col=0)
y_data = dataset.Passengers.values
data = y_data*jnp.ones(len(y_data))
data = jnp.log(data)
split = int(len(data)*0.9)
y_train, y_test = jnp.array(data[:split], dtype=jnp.float32), data[split:]
# ------------------------------------------------------------------------
# Params
num_seasons = 12
# ------------------------------------------------------------------------
# Model
def locallinear(y=None, num_seasons=12, future=0):
    
    N = y.shape[0]
    duration = N + future

    # empirical
    sd_y = jnp.std(y)
    y_init = y[0]
    dy_init = y[1] - y[0]
    max_y = jnp.abs(jnp.max(y))
    cauchy_sd = max_y * 0.01
    
    # priors
    season = numpyro.sample('season_init', dist.Cauchy(0, jnp.abs(y[:num_seasons]) * 0.3))
    season = jnp.concatenate([season[1:], season[:1]], axis=0)

    level = numpyro.sample('level_init', dist.Normal(y_init, sd_y))    
    slope = numpyro.sample('slope_init', dist.Normal(dy_init, sd_y))
    
    level_scale = numpyro.sample('level_scale', dist.HalfCauchy(cauchy_sd))
    slope_scale = numpyro.sample('slope_scale', dist.HalfCauchy(cauchy_sd))
    y_scale = numpyro.sample('y_scale', dist.HalfCauchy(cauchy_sd))

    # state transition
    
    def transition(state, t):
        level, slope, season = state
        
        season_t = -1 * jnp.sum(season[:-1])
        season = jnp.concatenate([season[1:], season[:1]], axis=0)
        
        level = numpyro.sample('level', dist.Normal(level + slope, level_scale))
        slope = numpyro.sample('slope', dist.Normal(slope, slope_scale))

        # combined state
        z_t = level + season_t
 
        return (level, slope, season), z_t        

    (level, slope, season), z_t = scan(transition, (level, slope, season), jnp.arange(1, duration))
    
    # observation equation
    if future == 0:
        y_pred = numpyro.sample('y', dist.Normal(z_t, y_scale), obs=y[1:])
    else:
        z_t = z_t[y.shape[0] - 1:]
        assert z_t.shape[0] == future
        y_pred = numpyro.sample('y', dist.Normal(z_t, y_scale))
    return level, slope, season

# ------------------------------------------------------------------------
# Inference
kernel = NUTS(locallinear)
mcmc = MCMC(kernel, num_warmup=5000, num_samples=5000, num_chains=4)
mcmc.run(random.PRNGKey(1), y_train, num_seasons=num_seasons)
mcmc.print_summary()
samples = mcmc.get_samples()

# ------------------------------------------------------------------------
# Forecast using approach here http://pyro.ai/numpyro/time_series_forecasting.html
def locallinear_forecast(future, sample, y, level, slope, season):

    y_scale = sample["y_scale"]
    level_scale = sample['level_scale']
    slope_scale = sample['slope_scale']
    #yfs = [0] * future 
    for t in range(future):
        # state transition
        season_t = -1 * jnp.sum(season[:-1])
        season = jnp.concatenate([season[1:], season[:1]], axis=0)
        level = numpyro.sample('level{}'.format(t), dist.Normal(level + slope, level_scale))
        slope = numpyro.sample('slope{}'.format(t), dist.Normal(slope, slope_scale))
        z_t = level + season_t
        # observation
        yf = numpyro.sample("yf[{}]".format(t), dist.Normal(z_t, y_scale))
        #yfs[t] = yf
        
def forecast(future, rng_key, sample, y, num_seasons):
    level, slope, season = handlers.substitute(locallinear, sample)(y, num_seasons)
    forecast_model = handlers.seed(locallinear_forecast, rng_key)
    forecast_trace = handlers.trace(forecast_model).get_trace(future, sample, y, level, slope, season)
    results = [forecast_trace["yf[{}]".format(t)]["value"]
               for t in range(future)]
    return jnp.stack(results, axis=0)

rng_keys = random.split(random.PRNGKey(1), samples["y_scale"].shape[0])
forecast_marginal = vmap(lambda rng_key, sample: forecast(
    len(y_test), rng_key, sample, y_train, num_seasons=num_seasons))(rng_keys, samples)

# ------------------------------------------------------------------------
# metrics
y_pred = jnp.median(forecast_marginal, axis=0)
print(y_pred)
sMAPE = jnp.mean(jnp.abs(jnp.exp(y_pred) - jnp.exp(y_test)) / jnp.abs(jnp.exp(y_pred) + jnp.exp(y_test))) * 200
msqrt = jnp.sqrt(jnp.mean((jnp.exp(y_pred) - jnp.exp(y_test)) ** 2))
print("sMAPE: {:.2f}%, rmse: {:.2f}".format(sMAPE, msqrt))
# ------------------------------------------------------------------------
# figures
plt.figure(figsize=(12, 8))
plt.plot(jnp.exp(data))
t_future = jnp.arange(len(y_data))[split:]
t_future_min = min(t_future)
hpd_low, hpd_high = hpdi(forecast_marginal)
plt.plot(t_future, jnp.exp(y_pred), lw=2)
plt.axvline(x=t_future_min, color='k')
plt.fill_between(t_future, jnp.exp(hpd_low), jnp.exp(hpd_high), alpha=0.3) 
p50 = jnp.percentile(forecast_marginal, q=50, axis=0)
plt.plot(t_future, jnp.exp(p50))

# ------------------------------------------------------------------------
# Forecast using approach here https://github.com/pyro-ppl/numpyro/blob/87e94c0c988e0092ff0b29f4651ffcd022d68bec/notebooks/source/time_series_forecasting.ipynb

predictive = Predictive(locallinear, samples, return_sites=["y"])
forecast_marginal = predictive(random.PRNGKey(1), y_train, num_seasons=num_seasons, future=len(y_test))["y"]

# ------------------------------------------------------------------------
# metrics
y_pred = jnp.median(forecast_marginal, axis=0)
print(y_pred)
sMAPE = jnp.mean(jnp.abs(jnp.exp(y_pred) - jnp.exp(y_test)) / jnp.abs(jnp.exp(y_pred) + jnp.exp(y_test))) * 200
msqrt = jnp.sqrt(jnp.mean((jnp.exp(y_pred) - jnp.exp(y_test)) ** 2))
print("sMAPE: {:.2f}%, rmse: {:.2f}".format(sMAPE, msqrt))
# ------------------------------------------------------------------------
# figures
plt.figure(figsize=(12, 8))
plt.plot(jnp.exp(data))
t_future = jnp.arange(len(y_data))[split:]
t_future_min = min(t_future)
hpd_low, hpd_high = hpdi(forecast_marginal)
plt.plot(t_future, jnp.exp(y_pred), lw=2)
plt.axvline(x=t_future_min, color='k')
plt.fill_between(t_future, jnp.exp(hpd_low), jnp.exp(hpd_high), alpha=0.3) 
p50 = jnp.percentile(forecast_marginal, q=50, axis=0)
plt.plot(t_future, jnp.exp(p50))

One further reason I think it is related to the new scan function: when I write the same model without sample statements inside the transition function called by scan, both methods of producing forecasts give nearly identical results.

import os
from IPython.display import set_matplotlib_formats
import matplotlib.pyplot as plt
import pandas as pd
import jax.numpy as jnp
from jax import lax, random, vmap
from jax.nn import softmax
import numpyro; numpyro.set_host_device_count(4)
import numpyro.distributions as dist
from numpyro.distributions import transforms
from numpyro.diagnostics import autocorrelation, hpdi
from numpyro import handlers
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.contrib.control_flow.scan import scan

#assert numpyro.__version__.startswith('0.2.4')

# ------------------------------------------------------------------------
# Read Data
dataset = pd.read_csv("airline-passengers.csv", index_col=0)
y_data = dataset.Passengers.values
data = y_data*jnp.ones(len(y_data))
data = jnp.log(data)
split = int(len(data)*0.9)
y_train, y_test = jnp.array(data[:split], dtype=jnp.float32), data[split:]
# ------------------------------------------------------------------------
# Params
num_seasons = 12
# ------------------------------------------------------------------------
# Model
def locallinear(y=None, num_seasons=12, future=0):
    
    N = y.shape[0]
    duration = N + future

    # empirical
    sd_y = jnp.std(y)
    y_init = y[0]
    dy_init = y[1] - y[0]
    max_y = jnp.abs(jnp.max(y))
    cauchy_sd = max_y * 0.01
    
    # priors
    season = numpyro.sample('season_init', dist.Cauchy(0, jnp.abs(y[:num_seasons]) * 0.3))
    season = jnp.concatenate([season[1:], season[:1]], axis=0)

    level = numpyro.sample('level_init', dist.Normal(y_init, sd_y))    
    slope = numpyro.sample('slope_init', dist.Normal(dy_init, sd_y))
    
    #level_scale = numpyro.sample('level_scale', dist.HalfCauchy(cauchy_sd))
    #slope_scale = numpyro.sample('slope_scale', dist.HalfCauchy(cauchy_sd))
    y_scale = numpyro.sample('y_scale', dist.HalfCauchy(cauchy_sd))

    # state transition
    
    def transition(state, t):
        level, slope, season = state
        
        season_t = -1 * jnp.sum(season[:-1])
        season = jnp.concatenate([season[1:], season[:1]], axis=0)
        
        #level = numpyro.sample('level', dist.Normal(level + slope, level_scale))
        #slope = numpyro.sample('slope', dist.Normal(slope, slope_scale))
        level = level + slope

        # combined state
        z_t = level + season_t
 
        return (level, slope, season), z_t        

    (level, slope, season), z_t = scan(transition, (level, slope, season), jnp.arange(1, duration))
    
    # observation equation
    if future == 0:
        y_pred = numpyro.sample('y', dist.Normal(z_t, y_scale), obs=y[1:])
    else:
        z_t = z_t[y.shape[0] - 1:]
        assert z_t.shape[0] == future
        y_pred = numpyro.sample('y', dist.Normal(z_t, y_scale))
    return level, slope, season

# ------------------------------------------------------------------------
# Inference
kernel = NUTS(locallinear)
mcmc = MCMC(kernel, num_warmup=5000, num_samples=5000, num_chains=4)
mcmc.run(random.PRNGKey(1), y_train, num_seasons=num_seasons)
mcmc.print_summary()
samples = mcmc.get_samples()

# ------------------------------------------------------------------------
# Forecast using approach here http://pyro.ai/numpyro/time_series_forecasting.html
def locallinear_forecast(future, sample, y, level, slope, season):

    y_scale = sample["y_scale"]
    #level_scale = sample['level_scale']
    #slope_scale = sample['slope_scale']
    #yfs = [0] * future 
    for t in range(future):
        # state transition
        season_t = -1 * jnp.sum(season[:-1])
        season = jnp.concatenate([season[1:], season[:1]], axis=0)
        #level = numpyro.sample('level{}'.format(t), dist.Normal(level + slope, level_scale))
        level = level + slope
        #slope = numpyro.sample('slope{}'.format(t), dist.Normal(slope, slope_scale))
        z_t = level + season_t
        # observation
        yf = numpyro.sample("yf[{}]".format(t), dist.Normal(z_t, y_scale))
        #yfs[t] = yf
        
def forecast(future, rng_key, sample, y, num_seasons):
    level, slope, season = handlers.substitute(locallinear, sample)(y, num_seasons)
    forecast_model = handlers.seed(locallinear_forecast, rng_key)
    forecast_trace = handlers.trace(forecast_model).get_trace(future, sample, y, level, slope, season)
    results = [forecast_trace["yf[{}]".format(t)]["value"]
               for t in range(future)]
    return jnp.stack(results, axis=0)

rng_keys = random.split(random.PRNGKey(1), samples["y_scale"].shape[0])
forecast_marginal = vmap(lambda rng_key, sample: forecast(
    len(y_test), rng_key, sample, y_train, num_seasons=num_seasons))(rng_keys, samples)
        

# ------------------------------------------------------------------------
# metrics
y_pred = jnp.median(forecast_marginal, axis=0)
print(y_pred)
sMAPE = jnp.mean(jnp.abs(jnp.exp(y_pred) - jnp.exp(y_test)) / jnp.abs(jnp.exp(y_pred) + jnp.exp(y_test))) * 200
msqrt = jnp.sqrt(jnp.mean((jnp.exp(y_pred) - jnp.exp(y_test)) ** 2))
print("sMAPE: {:.2f}%, rmse: {:.2f}".format(sMAPE, msqrt))
# ------------------------------------------------------------------------
# figures
plt.figure(figsize=(12, 8))
plt.plot(jnp.exp(data))
t_future = jnp.arange(len(y_data))[split:]
t_future_min = min(t_future)
hpd_low, hpd_high = hpdi(forecast_marginal)
plt.plot(t_future, jnp.exp(y_pred), lw=2)
plt.axvline(x=t_future_min, color='k')
plt.fill_between(t_future, jnp.exp(hpd_low), jnp.exp(hpd_high), alpha=0.3) 
p50 = jnp.percentile(forecast_marginal, q=50, axis=0)
plt.plot(t_future, jnp.exp(p50))

# ------------------------------------------------------------------------
# Forecast using approach here https://github.com/pyro-ppl/numpyro/blob/87e94c0c988e0092ff0b29f4651ffcd022d68bec/notebooks/source/time_series_forecasting.ipynb

predictive = Predictive(locallinear, samples, return_sites=["y"])
forecast_marginal = predictive(random.PRNGKey(1), y_train, num_seasons=num_seasons, future=len(y_test))["y"]

# ------------------------------------------------------------------------
# metrics
y_pred = jnp.median(forecast_marginal, axis=0)
print(y_pred)
sMAPE = jnp.mean(jnp.abs(jnp.exp(y_pred) - jnp.exp(y_test)) / jnp.abs(jnp.exp(y_pred) + jnp.exp(y_test))) * 200
msqrt = jnp.sqrt(jnp.mean((jnp.exp(y_pred) - jnp.exp(y_test)) ** 2))
print("sMAPE: {:.2f}%, rmse: {:.2f}".format(sMAPE, msqrt))
# ------------------------------------------------------------------------
# figures
plt.figure(figsize=(12, 8))
plt.plot(jnp.exp(data))
t_future = jnp.arange(len(y_data))[split:]
t_future_min = min(t_future)
hpd_low, hpd_high = hpdi(forecast_marginal)
plt.plot(t_future, jnp.exp(y_pred), lw=2)
plt.axvline(x=t_future_min, color='k')
plt.fill_between(t_future, jnp.exp(hpd_low), jnp.exp(hpd_high), alpha=0.3) 
p50 = jnp.percentile(forecast_marginal, q=50, axis=0)
plt.plot(t_future, jnp.exp(p50))

@fehiepsi (Member, Author):
> From what I can tell, the forecasts using the Predictive class do not seem to work correctly.

Whoa, we really love examples that show something is wrong! I'll look deeper into your example. Thanks, @jawolf314!

@fehiepsi (Member, Author) commented Jun 18, 2020

@jawolf314 I think I know the reason. In the current implementation, we need the conditioned value (the value from the posterior, which has the same size as the training data) to have the same size as the scan length, but we didn't raise an error when that is not the case. With the current implementation, and with the help of the namespace feature in #593, the following pattern will work for you:

last_val, ys_train = scan(fn, init_val, xs_train)
_, ys_pred = scan(namespace(fn, 'pred'), last_val, xs_test)

then you can collect pred::level, pred::slope using Predictive. :)

But after getting up this morning, I think we can support the case where scan length > length of the conditioned value using lax.cond. Let me make the enhancement today and I'll ping you when it's ready. In the meantime, could you try the above solution to verify that my assumption is right? We don't have a namespace handler yet, but you can simply create a new fn with the names of the level, slope samples changed to level_pred, slope_pred or something like that.
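For a sense of the lax.cond idea, here is a self-contained toy under my own assumptions (not the actual implementation): a scan body can replay a conditioned series while it lasts and fall back to fresh samples beyond it.

import jax.numpy as jnp
from jax import lax, random

observed = jnp.array([0.3, -1.2, 0.8])   # conditioned values for the first steps
T = 6                                     # scan length exceeds len(observed)

def body(rng_key, t):
    rng_key, subkey = random.split(rng_key)
    val = lax.cond(
        t < observed.shape[0],
        lambda k: observed[t],            # inside the conditioned range: replay
        lambda k: random.normal(k),       # beyond it: draw a fresh sample
        subkey,
    )
    return rng_key, val

_, vals = lax.scan(body, random.PRNGKey(0), jnp.arange(T))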

Btw, if things work well, could you help us by contributing your example (like those in the examples folder), or even better, writing a small tutorial with the above model? Currently, we don't have the bandwidth to both develop new features and explore useful applications that illustrate them. Your contributions are very welcome, and I believe they would be very helpful for other users. Thanks!!

@fehiepsi (Member, Author) commented Jun 18, 2020

@jawolf314 I have made an enhancement for your use case. Please let me know if things work as expected for you. :)

(I might make a couple more changes to clean up some gross spots in the code after discussing with the other devs, but I'll make sure things keep working as expected for you. If you observe something still wrong, please let us know.)

@jawolf314 commented Jun 18, 2020

Thanks @fehiepsi!

> @jawolf314 I think I know the reason. In the current implementation, we need the conditioned value (the value from the posterior, which has the same size as the training data) to have the same size as the scan length, but we didn't raise an error when that is not the case. With the current implementation, and with the help of the namespace feature in #593, the following pattern will work for you:
>
>     last_val, ys_train = scan(fn, init_val, xs_train)
>     _, ys_pred = scan(namespace(fn, 'pred'), last_val, xs_test)
>
> then you can collect pred::level, pred::slope using Predictive. :)
>
> But after getting up this morning, I think we can support the case where scan length > length of the conditioned value using lax.cond. Let me make the enhancement today and I'll ping you when it's ready. In the meantime, could you try the above solution to verify that my assumption is right? We don't have a namespace handler yet, but you can simply create a new fn with the names of the level, slope samples changed to level_pred, slope_pred or something like that.

Both methods give similar results after updating the model as you suggested. I just saw you pushed d7583fc; I will test the original version of the locallinear model on the updated code and comment shortly.

def locallinear(y=None, num_seasons=12, future=0):
    
    N = y.shape[0]
    duration = N + future
    #print(N, duration)

    # empirical
    sd_y = jnp.std(y)
    y_init = y[0]
    dy_init = y[1] - y[0]
    max_y = jnp.abs(jnp.max(y))
    cauchy_sd = max_y * 0.01
    
    # priors
    season = numpyro.sample('season_init', dist.Cauchy(0, jnp.abs(y[:num_seasons]) * 0.3))
    season = jnp.concatenate([season[1:], season[:1]], axis=0)

    level = numpyro.sample('level_init', dist.Normal(y_init, sd_y))    
    slope = numpyro.sample('slope_init', dist.Normal(dy_init, sd_y))
    
    level_scale = numpyro.sample('level_scale', dist.HalfCauchy(cauchy_sd))
    slope_scale = numpyro.sample('slope_scale', dist.HalfCauchy(cauchy_sd))
    y_scale = numpyro.sample('y_scale', dist.HalfCauchy(cauchy_sd))

    # state transition
    
    def transition(state, t):
        level, slope, season = state
        
        season_t = -1 * jnp.sum(season[:-1])
        season = jnp.concatenate([season[1:], season[:1]], axis=0)
        
        level = numpyro.sample('level', dist.Normal(level + slope, level_scale))
        slope = numpyro.sample('slope', dist.Normal(slope, slope_scale))

        # combined state
        z_t = level + season_t
 
        return (level, slope, season), z_t        

    # added https://github.com/pyro-ppl/numpyro/pull/595#issuecomment-646088220
    def transition_pred(state, t):
        level, slope, season = state
        
        season_t = -1 * jnp.sum(season[:-1])
        season = jnp.concatenate([season[1:], season[:1]], axis=0)
        
        level = numpyro.sample('level_pred', dist.Normal(level + slope, level_scale))
        slope = numpyro.sample('slope_pred', dist.Normal(slope, slope_scale))

        # combined state
        z_t = level + season_t
 
        return (level, slope, season), z_t   

    
    # updated following https://github.com/pyro-ppl/numpyro/pull/595#issuecomment-646088220
    (level, slope, season), z_t = scan(transition, (level, slope, season), jnp.arange(1, N))
    # added following https://github.com/pyro-ppl/numpyro/pull/595#issuecomment-646088220
    if future != 0:
        _, z_t_pred = scan(transition_pred, (level, slope, season), jnp.arange(N, duration))
    
    # observation equation
    if future == 0:
        y_pred = numpyro.sample('y', dist.Normal(z_t, y_scale), obs=y[1:])
    else:
        # updated following https://github.com/pyro-ppl/numpyro/pull/595#issuecomment-646088220
        assert z_t_pred.shape[0] == future
        y_pred = numpyro.sample('y', dist.Normal(z_t_pred, y_scale))        
        
    return level, slope, season

> Btw, if things work well, could you help us by contributing your example (like those in the examples folder), or even better, writing a small tutorial with the above model? Currently, we don't have the bandwidth to both develop new features and explore useful applications that illustrate them. Your contributions are very welcome, and I believe they would be very helpful for other users. Thanks!!

Yes, thank you. I am happy to contribute an example.

@jawolf314:
@fehiepsi d7583fc is working great. The Predictive class gives results similar to the approach using separate model functions for inference and forecasting.

@fehiepsi (Member, Author):
Thanks for the feedback, @jawolf314! Glad that it works. 👯‍♂️

@fehiepsi (Member, Author):
This PR is ready for review. I am pretty happy with the current implementation (all the gross points I flagged in comments on previous commits are removed now). :D

@neerajprad (Member) left a comment:

I will take a closer look at the distribution flattening/unflattening logic. This is looking really good and is a very important feature to have. Looking forward to merging this soon.

@@ -214,7 +214,7 @@ def __init__(self, df, validate_args=None):

 @copy_docs_from(Distribution)
 class GaussianRandomWalk(Distribution):
-    arg_constraints = {'num_steps': constraints.positive_integer, 'scale': constraints.positive}
+    arg_constraints = {'scale': constraints.positive, 'num_steps': constraints.positive_integer}
Member:
Is there a reason for this change?

EDIT: I see why we have this. Note my comment below.

"""
This primitive scans a function over the leading array axes of
`xs` while carrying along state. See :func:`jax.lax.scan` for more
information.
Member:
@fehiepsi - What are the current limitations of scan? e.g. distributions, compatibility with other handlers, etc. Maybe we can add a note to that effect.

Member Author (@fehiepsi):
I think the current limitations are:

  • It only records sample and deterministic sites in f, not param. The main reason is that we would need logic to flatten/unflatten the constraint field, and I was too lazy to do so.
  • I think most handlers (e.g. scale, mask) will work because we do apply_stack for scanned messages. The one that won't work is plate, where users will need to move plate statements inside f. I'll add a note about it. Do you suspect anything else that won't work out of the box?
  • For distributions, I raise errors for unsupported distributions and tell users to report the issue to us.

Member:

I think it is totally fine not to support param to begin with. I think most effect handlers should work out of the box, but let me take a look in detail.

We can add a disclaimer:

.. warning:: This is an experimental utility function that allows users to use JAX control flow with NumPyro's effect handlers. Currently, sample and deterministic sites within the scan body are supported. If you notice that any effect handlers or distributions are unsupported, please file an issue.

Member Author (@fehiepsi):
Thanks Neeraj, I'll add it.

cls.tree_flatten,
cls.tree_unflatten)

def tree_flatten(self):
Member:
I think it is probably better not to expose a default implementation. The reason is that this relies on an implementation detail of CPython (Python 3.6+), namely that dictionaries preserve insertion order. Can we raise NotImplementedError for both of these? If the default works for a lot of distributions, we could use an OrderedDict for arg_constraints. That is, however, an implicit ordering constraint that users would need to know about when writing their own distribution classes, so it is best to just raise an error.

Member Author (@fehiepsi):
This default implementation works because most distributions only need ndarray args in their constructor (with some special cases such as event_dim in Delta, covariance_matrix=None in MVN, or dimension=4 in LKJ, ...). I will try to use the inspect module here to guarantee that the order of args is matched.
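A possible shape for that inspect-based idea (my sketch, not the code that was committed): read the constructor signature and flatten parameters in the constructor's own argument order.

import inspect

def constructor_arg_names(cls):
    # constructor parameter names, in declaration order
    sig = inspect.signature(cls.__init__)
    return [name for name in sig.parameters
            if name not in ("self", "validate_args")]

def tree_flatten(self):
    # flatten only the constrained args, following constructor order
    names = [n for n in constructor_arg_names(type(self))
             if n in self.arg_constraints]
    return tuple(getattr(self, n) for n in names), None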

Member:
Actually, another possibility is to sort by key here:

    return tuple(getattr(self, param) for param in sorted(self.arg_constraints.keys())), None

and then, for unflatten:

    return cls(**dict(zip(sorted(cls.arg_constraints.keys()), params)))

Would that work?

Member Author (@fehiepsi):
Oh, nice trick with the unflatten! Yes, that should work. Many thanks!
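Putting the sorted-keys trick together, the base-class methods might look roughly like this (a hedged sketch; the merged code may differ):

class Distribution:
    # sketch of just the pytree hooks, not NumPyro's full base class
    arg_constraints = {}

    def tree_flatten(self):
        # sorted() gives a deterministic order that does not rely on
        # dict insertion order
        return tuple(getattr(self, param)
                     for param in sorted(self.arg_constraints.keys())), None

    @classmethod
    def tree_unflatten(cls, aux_data, params):
        # arg_constraints is a class attribute, so cls (not self) suffices here
        return cls(**dict(zip(sorted(cls.arg_constraints.keys()), params)))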

numpyro/distributions/continuous.py (resolved review thread)
    # this branch happens when substitute_fn is init_strategy,
    # where we apply init_strategy to each element in the scanned series
    return value
elif value_ndim == fn_ndim + 1:
Member:
So the time dim is always taken to be the first dim, even in the presence of outermost plate dims? I think that should be fine, but could you add a note in the docstring/comments if that's the case?

Member Author (@fehiepsi):
You are right, the time dim will be the first dim. I think that is covered in the note in the docs (plate cannot be used as a context for scan). We also raise an error if users do:

with plate(...):
    scan(...)
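In other words, the supported pattern places the plate inside the scan body. Here is a minimal runnable sketch, assuming the import path used elsewhere in this thread (shapes are illustrative):

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.contrib.control_flow.scan import scan

def model():
    def body_fn(carry, x):
        with numpyro.plate("dims", 3):            # plate lives inside the body
            z = numpyro.sample("z", dist.Normal(0., 1.))
        return carry, z

    _, zs = scan(body_fn, 0., jnp.arange(10))
    return zs                                     # shape (10, 3): time dim first

zs = handlers.seed(model, random.PRNGKey(0))()

# By contrast, wrapping the scan call itself in a plate raises an error:
# with numpyro.plate("dims", 3):
#     scan(body_fn, 0., jnp.arange(10))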

else:
    raise RuntimeError(f"Substituted value for site {site['name']} "
                       "requires length greater than or equal to scan length."
                       f" Expected length >= {length}, but got {shape[0]}.")
Member:

Shouldn't it be less than the scan length (the length of the time dim)? The message seems to suggest otherwise.

Member Author (@fehiepsi):
Oops... thanks for pointing it out!

@neerajprad (Member) left a comment:
I just had some minor clarifying comments. Happy to merge otherwise! 🎉

@neerajprad merged commit b28fea9 into pyro-ppl:master on Jun 24, 2020
Successfully merging this pull request may close these issues.

Predictive distribution fails on model with lax.scan