
AR2 model example question #1346

Closed
hesenp opened this issue Feb 23, 2022 · 7 comments

@hesenp
Contributor

hesenp commented Feb 23, 2022

I was reading this example doc about autoregressive time series modeling with numpyro and realized there's some confusion about the AR(2) concept in it.

In the classical textbook sense, $$y_t$$ carries forward not only the deterministic terms but also the random terms from earlier time steps, e.g. $$e_t$$, $$e_{t-1}$$, etc., because each lagged value is itself random.
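Written out with a constant $$c$$ and lag coefficients $$a_1, a_2$$ (symbols chosen here for illustration), the textbook AR(2) recursion and one step of substitution show how the earlier noise term $$e_{t-1}$$ ends up inside $$y_t$$:

$$
\begin{aligned}
y_t &= c + a_1 y_{t-1} + a_2 y_{t-2} + e_t, \qquad e_t \sim \mathcal{N}(0, \sigma^2) \\
    &= c + a_1\left(c + a_1 y_{t-2} + a_2 y_{t-3} + e_{t-1}\right) + a_2 y_{t-2} + e_t \\
    &= c(1 + a_1) + (a_1^2 + a_2)\, y_{t-2} + a_1 a_2\, y_{t-3} + a_1 e_{t-1} + e_t
\end{aligned}
$$

A mean path that is propagated only from earlier means, $$m_t = c + a_1 m_{t-1} + a_2 m_{t-2}$$, contains no noise at all, so it drops the $$a_1 e_{t-1}$$ term and every older one.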

However, the example code generates all the means of the time series deterministically with the scan function and only adds the noise terms on top.

This might cause $$\sigma$$ to be under- or overestimated compared with the actual AR(2) specification. Is my understanding correct?
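A quick way to see the concern about $$\sigma$$ is to simulate a textbook AR(2) series in plain Python and compare two residual scales: residuals against the one-step conditional mean (built from the observed lags) versus residuals against a mean path propagated deterministically from the initial values, which is how the example is described above. This is a minimal sketch; the parameter values are hypothetical, chosen only so that the process is stationary. In this setup the deterministic-mean residuals have to absorb the propagated noise, so a model built that way would tend to overestimate $$\sigma$$.

```python
import math
import random

# Hypothetical parameters for illustration (a stationary AR(2) process).
CONST, A1, A2, SIGMA = 0.0, 0.5, 0.3, 1.0
Y0, Y1 = 0.3, -0.1


def simulate_ar2(n, seed=0):
    """Textbook AR(2): each sampled value feeds the next step, so past noise propagates."""
    rng = random.Random(seed)
    ys = [Y0, Y1]
    for _ in range(n):
        m_t = CONST + A1 * ys[-1] + A2 * ys[-2]
        ys.append(rng.gauss(m_t, SIGMA))
    return ys  # includes the two initial values


def rms(values):
    """Root mean square of a list of residuals."""
    return math.sqrt(sum(v * v for v in values) / len(values))


def residual_scale_conditional(ys):
    """Residuals against the one-step-ahead mean built from the *observed* lags."""
    return rms([ys[t] - (CONST + A1 * ys[t - 1] + A2 * ys[t - 2])
                for t in range(2, len(ys))])


def residual_scale_mean_path(ys):
    """Residuals against a mean path propagated deterministically from the first two values."""
    means = [ys[0], ys[1]]
    for _ in range(len(ys) - 2):
        means.append(CONST + A1 * means[-1] + A2 * means[-2])
    return rms([y - m for y, m in zip(ys[2:], means[2:])])


ys = simulate_ar2(20_000)
print(f"conditional residual scale: {residual_scale_conditional(ys):.2f}")  # close to SIGMA
print(f"mean-path residual scale:   {residual_scale_mean_path(ys):.2f}")    # noticeably larger
```

With these values the stationary standard deviation of the series is about 1.5, and that is roughly what the mean-path residuals pick up instead of the true $$\sigma = 1$$.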

@MarcoGorelli
Contributor

Hi @hesenp - how should the example be written?

I was using https://mc-stan.org/docs/2_29/stan-users-guide/autoregressive.html as a reference
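The Stan page linked above writes the AR(K) likelihood conditionally on the observed lagged values: the mean for each `y[n]` is built from the observed `y[n-1]`, ..., `y[n-K]`, not from previously computed means. A minimal plain-Python sketch of that conditional log-likelihood for K = 2 follows; the function name, the simulated data, and the parameter values are illustrative assumptions, not taken from the Stan guide.

```python
import math
import random


def ar2_conditional_loglik(ys, const, a1, a2, sigma):
    """Sum of Normal log-densities of ys[t] given the observed ys[t-1] and ys[t-2]."""
    total = 0.0
    for t in range(2, len(ys)):
        mu = const + a1 * ys[t - 1] + a2 * ys[t - 2]  # mean uses observed lags
        z = (ys[t] - mu) / sigma
        total += -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)
    return total


# Hypothetical data: simulate an AR(2) series with known parameters.
rng = random.Random(1)
ys = [0.3, -0.1]
for _ in range(2_000):
    ys.append(rng.gauss(0.0 + 0.5 * ys[-1] + 0.3 * ys[-2], 1.0))

ll_true = ar2_conditional_loglik(ys, 0.0, 0.5, 0.3, 1.0)
ll_no_lags = ar2_conditional_loglik(ys, 0.0, 0.0, 0.0, 1.0)  # ignores the lags entirely
print(ll_true > ll_no_lags)
```

Because every term conditions on observed values, the noise that entered earlier observations is accounted for automatically; nothing has to be carried through a deterministic mean.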

@omarfsosa
Contributor

omarfsosa commented Feb 23, 2022

Hmmm yeah I believe @hesenp is correct. The transition function should include a sample statement so that the noise is carried forward.
This is how I'd normally write an AR2 model:

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro import sample
from numpyro.contrib.control_flow import scan
from numpyro.infer import Predictive, NUTS, MCMC


def ar2(num_timesteps, y0, y1, y=None):
    """
    An autoregressive (K=2) model.

    Parameters
    ----------
    num_timesteps: int, positive
        The total number of timesteps to model

    y: ndarray, shape (num_timesteps,), optional
        The observed values beyond y0 and y1

    y0, y1: floats
        The initial values of the process, in time order (y0 comes first)
    """
    a1 = sample("a1", dist.Normal())
    a2 = sample("a2", dist.Normal())
    const = sample("const", dist.Normal())
    sigma = sample("sigma", dist.Exponential())

    def transition(carry, _):
        y_prev, y_prev_prev = carry
        m_t = const + a1 * y_prev + a2 * y_prev_prev
        # Sampling y_t inside the loop means its noise feeds into later steps
        y_t = sample("y", dist.Normal(m_t, sigma))
        carry = (y_t, y_prev)
        return carry, None

    timesteps = jnp.arange(num_timesteps)
    # carry is (y_prev, y_prev_prev), so the most recent initial value (y1) goes first
    init = (y1, y0)
    # With y=None nothing is conditioned, so "y" is drawn from the prior
    with numpyro.handlers.condition(data={"y": y}):
        scan(transition, init, timesteps)


# Usage in prior simulation
num_timesteps = 40
y0, y1 = 0.3, -0.1

prior = Predictive(ar2, num_samples=10)
prior_rng = jax.random.PRNGKey(0)
prior_samples = prior(prior_rng, num_timesteps, y0, y1)
# prior_samples["y"] has shape (10, num_timesteps)
```

(Note that I'm using numpyro's `scan` instead of `lax.scan`; numpyro's version supports `sample` statements inside the loop body.)
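For readers less familiar with scan, its carry/output contract can be written as a plain Python loop, mirroring the reference semantics given in the `jax.lax.scan` documentation. The sketch below stands in `random.gauss` for numpyro's `sample` purely as an illustration (it does none of numpyro's site bookkeeping), and the parameter values are hypothetical. The point it shows is that the *sampled* value, not the mean, is what gets carried to the next step.

```python
import random


def scan(f, init, xs):
    """Plain-Python reference semantics: thread `carry` through f and collect the outputs."""
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


rng = random.Random(0)
const, a1, a2, sigma = 0.0, 0.5, 0.3, 1.0  # hypothetical values


def transition(carry, _):
    y_prev, y_prev_prev = carry
    m_t = const + a1 * y_prev + a2 * y_prev_prev
    y_t = rng.gauss(m_t, sigma)  # stands in for sample("y", dist.Normal(m_t, sigma))
    return (y_t, y_prev), y_t    # the sampled value is carried forward, not m_t


y0, y1 = 0.3, -0.1
# carry is (y_prev, y_prev_prev), so the most recent initial value goes first
final_carry, path = scan(transition, (y1, y0), range(40))
print(len(path), final_carry)
```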

@MarcoGorelli
Contributor

Ah, thanks @omarfsosa!

@hesenp
Contributor Author

hesenp commented Feb 23, 2022

Thanks a lot for sharing this, @omarfsosa. I was planning to sketch an example for this and then went to sleep. I'm glad to wake up and find one here already.

@omarfsosa
Contributor

@MarcoGorelli Do you want to open the PR to update the example? I'm also happy to do it if you don't have time at the moment.

@MarcoGorelli
Contributor

Hey @omarfsosa! I'm quite busy at the moment and wouldn't have a chance to update this until next week at the earliest, so I'm happy for you (since you posted the correct version of the example here) or @hesenp (since he first noticed the mistake) to author the PR.

@hesenp
Contributor Author

hesenp commented Feb 27, 2022

Yeah, I can give it a shot. Let me see if I can pull this off tonight.
