Skip to content

Commit

Permalink
Make figure in AR2 example reproducible (#1816)
Browse files Browse the repository at this point in the history
* Add code to reproduce ar2 plot

* use m_t for plot

* make pre-commit happy

* implement mean tracking for loop version

* restore mean tracking in scan version
  • Loading branch information
damonbayer authored Jun 23, 2024
1 parent 0924135 commit 491a0cd
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
Binary file modified docs/source/_static/img/examples/ar2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
32 changes: 28 additions & 4 deletions examples/ar2.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
import os
import time

import matplotlib.pyplot as plt

import jax
from jax import random
import jax.numpy as jnp
Expand All @@ -54,13 +56,15 @@ def transition(carry, _):
m_t = const + alpha_1 * y_prev + alpha_2 * y_prev_prev
y_t = numpyro.sample("y", dist.Normal(m_t, sigma))
carry = (y_t, y_prev)
return carry, None
return carry, m_t

timesteps = jnp.arange(y.shape[0] - 2)
init = (y[1], y[0])

with numpyro.handlers.condition(data={"y": y[2:]}):
scan(transition, init, timesteps)
_, mu = scan(transition, init, timesteps)

numpyro.deterministic("mu", mu)


def ar2_for_loop(y):
Expand All @@ -71,13 +75,16 @@ def ar2_for_loop(y):

y_prev = y[1]
y_prev_prev = y[0]

mu = []
for i in range(2, len(y)):
m_t = const + alpha_1 * y_prev + alpha_2 * y_prev_prev
mu.append(m_t)
y_t = numpyro.sample("y_{}".format(i), dist.Normal(m_t, sigma), obs=y[i])
y_prev_prev = y_prev
y_prev = y_t

numpyro.deterministic("mu", jnp.asarray(mu))


def run_inference(model, args, rng_key, y):
start = time.time()
Expand Down Expand Up @@ -110,7 +117,24 @@ def main(args):
# faster
model = ar2_scan

run_inference(model, args, rng_key, y)
samples = run_inference(model, args, rng_key, y)

# do prediction
mean_prediction = samples["mu"].mean(axis=0)

# make plots
fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)

# plot training data
ax.plot(t, y, color="blue", label="True values")
# plot mean prediction
# note that we can't make predictions for the first two points,
# because they don't have lagged values to use for prediction.
ax.plot(t[2:], mean_prediction, color="orange", label="Mean predictions")
ax.set(xlabel="time", ylabel="y", title="AR2 process")
ax.legend()

plt.savefig("ar2.png")


if __name__ == "__main__":
Expand Down

0 comments on commit 491a0cd

Please sign in to comment.