Issues related to using predictive methods multiple times on the same model #282

Closed
damonbayer opened this issue Jul 19, 2024 · 2 comments
Labels
invalid (This doesn't seem right), pyrenew (related to pyrenew internals)

Comments

@damonbayer
Collaborator

damonbayer commented Jul 19, 2024

Calling a predictive method more than once on the same model, as in the snippet below (and in other similar exercises), leads to an error.

```python
import jax.numpy as jnp
import numpyro.distributions as dist
import pyrenew.transformation as t
from pyrenew.deterministic import DeterministicPMF
from pyrenew.latent import (
    InfectionInitializationProcess,
    Infections,
    InitializeInfectionsZeroPad,
)
from pyrenew.metaclass import DistributionalRV
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.observation import PoissonObservation
from pyrenew.process import RtRandomWalkProcess

# Generation interval: uniform PMF over 4 days
pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25])
gen_int = DeterministicPMF(pmf_array, name="gen_int")

# Initial infections: log-normal prior on I0, zero-padded to the
# length of the generation interval
I0 = InfectionInitializationProcess(
    "I0_initialization",
    DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"),
    InitializeInfectionsZeroPad(n_timepoints=gen_int.size()),
    t_unit=1,
)
latent_infections = Infections()
observed_infections = PoissonObservation("poisson_rv")

# Rt: random walk on the transformed (log) scale
rt = RtRandomWalkProcess(
    Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0),
    Rt_transform=t.ExpTransform().inv,
    Rt_rw_dist=dist.Normal(0, 0.025),
)
model = RtInfectionsRenewalModel(
    I0_rv=I0,
    gen_int_rv=gen_int,
    latent_infections_rv=latent_infections,
    infection_obs_process_rv=observed_infections,
    Rt_process_rv=rt,
)

n_tp = 30

# First prior predictive call on the model
model.prior_predictive(
    numpyro_predictive_args={"num_samples": 20},
    n_timepoints_to_simulate=n_tp,
)

# Repeating the same call on the same model object raises the error below
model.prior_predictive(
    numpyro_predictive_args={"num_samples": 20},
    n_timepoints_to_simulate=n_tp,
)
```
```
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was <lambda> at /Users/damon/Library/Caches/pypoetry/virtualenvs/pyrenew-GjeTh4Fr-py3.12/lib/python3.12/site-packages/jax/_src/lax/control_flow/loops.py:2111 traced for scan.
```
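The error message describes a general JAX failure mode: a function traced inside `jax.lax.scan` stores an intermediate value (a tracer) somewhere outside the function, and a later computation picks that stale tracer up again. The sketch below reproduces that pattern in plain JAX, with no pyrenew or numpyro involved. The `Cache` class, `step`, `run`, and `use_cache` are hypothetical names made up for illustration; this is not the code path inside pyrenew or numpyro that actually leaked the tracer here.

```python
import jax
import jax.numpy as jnp


class Cache:
    """Hypothetical stateful helper that remembers a value seen while tracing."""

    def __init__(self):
        self.last = None


cache = Cache()


def step(carry, x):
    new = carry + x
    # Side effect: stores a tracer on a Python object outside the scan
    cache.last = new
    return new, new


def run(xs):
    total, _ = jax.lax.scan(step, jnp.float32(0.0), xs)
    return total


# The first call works, but cache.last now holds a leaked tracer
run(jnp.arange(3.0))


def use_cache(xs):
    # A later computation that reuses the leaked tracer fails
    return run(xs) + cache.last


try:
    use_cache(jnp.arange(3.0))
except jax.errors.UnexpectedTracerError as err:
    print("leaked tracer:", type(err).__name__)


def step_pure(carry, x):
    new = carry + x
    # Intermediate values leave the scan only through its explicit outputs
    return new, new


total, history = jax.lax.scan(step_pure, jnp.float32(0.0), jnp.arange(3.0))
```

The pure version follows the rule quoted in the error message: every value a traced function produces is returned from the function and not saved to outside state, so calling it repeatedly is safe.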
@damonbayer added the invalid (This doesn't seem right) and pyrenew (related to pyrenew internals) labels Jul 19, 2024
@dylanhmorris
Collaborator

This should be resolved by pyro-ppl/numpyro#1843.

@dylanhmorris
Collaborator

This is resolved.
