tracer error in blocked AutoGuide #1753
Hi @amifalk, those autoguides are not designed to be composed with vmap after construction, because they need to be initialized (to inspect the model and generate things like prototype_trace). Something like this will work
|
I think there's still an issue here. The suggested approach yields the same error as before (though AutoGuides are not registered as pytrees, so they cannot be returned from the vmapped function, which is why guide_init returns nothing):
import jax
from jax import random
from numpyro.handlers import block, seed
from numpyro.infer.autoguide import AutoDelta

# model and model_w_deterministic are the models from this report (not shown here).
def guide_init(rng_seed):
    guide = AutoDelta(block(seed(model, rng_seed=rng_seed), hide=['b']))
    seed(guide, rng_seed=rng_seed)()  # run the guide once to initialize it
    return

keys = random.split(random.PRNGKey(0))
jax.vmap(guide_init)(keys)  # this works

def guide_init_deterministic(rng_seed):
    guide = AutoDelta(block(seed(model_w_deterministic, rng_seed=rng_seed), hide=['b']))
    seed(guide, rng_seed=rng_seed)()
    return

keys = random.split(random.PRNGKey(0))
jax.vmap(guide_init_deterministic)(keys)  # tracer error
|
@fehiepsi I've traced the source to this while loop. If I set
|
If we use a Python while loop, then the condition needs to be a Python value like True or False; having a JAX object there won't work. What is your use case, by the way? |
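The illustration below makes the point above concrete. It is a small JAX-only sketch and is unrelated to NumPyro's actual code. Under vmap, a Python while loop on a traced value fails, whereas jax.lax.while_loop stages the condition into the computation. The function names python_loop and lax_loop are made up for the example.

```python
import jax
import jax.numpy as jnp


def python_loop(x):
    # A Python `while` calls bool() on `x < 10`; under jit/vmap that is a
    # tracer with no single concrete value, so JAX raises instead of looping.
    while x < 10:
        x = x + 1
    return x


def lax_loop(x):
    # jax.lax.while_loop stages the condition into the computation, so it
    # works on traced (and batched) values.
    return jax.lax.while_loop(lambda v: v < 10, lambda v: v + 1, x)


try:
    jax.vmap(python_loop)(jnp.arange(3))
    python_loop_failed = False
except jax.errors.ConcretizationTypeError:
    # Newer JAX versions raise TracerBoolConversionError, a subclass of this.
    python_loop_failed = True

lax_result = jax.vmap(lax_loop)(jnp.arange(3))
print(python_loop_failed)  # True
print(lax_result)          # [10 10 10]
```

Outside of any transformation, python_loop works on concrete inputs. The failure only appears once the loop condition becomes a tracer, which is what happens inside a vmapped or jitted SVI run.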
I have a blocked model with a deterministic site that I'm running some simulation studies on. I want to see how variations in the structure of the dataset and in the model hyperparameters affect performance, and I also want to be able to select the best result over multiple initializations. Doing this sequentially is very slow (a small grid of hyperparameters took around 40 minutes), but after vmapping/pmapping on a GPU I can run the entire grid in parallel; in my case this reduced the fitting time to 7 seconds. Unfortunately, if I try to vmap the blocked model when deterministic sites are present, it throws this error, so instead I have to recompute the deterministic sites after fitting. I need to block the model to define an AutoGuide that is compatible with enumeration (blocking out the enumerated sites), but this would likely also be a problem for people using AutoGuideList. |
I think you can do something like
def run_svi(...):
    svi = ...
    svi_result = svi.run(...)
    return svi_result

svi_results = vmap(run_svi)(...)
|
Unfortunately this still seems to throw the same tracer error:
import jax
import numpyro
from jax import random
from numpyro.handlers import block, seed
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta

def run_svi(key):
    optimizer = numpyro.optim.Adam(step_size=.01)
    guide = AutoDelta(block(seed(model, rng_seed=0), hide=['b']))
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    return svi.run(key, 100, progress_bar=False)

def run_svi_deterministic(key):
    optimizer = numpyro.optim.Adam(step_size=.01)
    guide = AutoDelta(block(seed(model_w_deterministic, rng_seed=0), hide=['b']))
    svi = SVI(model_w_deterministic, guide, optimizer, loss=Trace_ELBO())
    return svi.run(key, 100, progress_bar=False)

keys = random.split(random.PRNGKey(0), 2)
jax.vmap(run_svi)(keys)  # works
jax.vmap(run_svi_deterministic)(keys)  # tracer error from the while loop in find_valid_initial_params
|
Can we make it so that AutoGuides only collect non-enumerated model sample sites? This wouldn't fix the problem for all blocked models, but it would make collecting deterministic sites possible under batched SVI for my use case. I think this would only need a one-line change around here, where we just ignore sample sites in the prototype trace that have |
Thanks @amifalk! There is indeed leakage here with the seed handler. I haven't been able to figure out why yet. Posting here for reference:
import numpyro
import numpyro.distributions as dist
import jax

def model():
    return numpyro.sample('a', dist.Normal(0, 1))

def run(key):
    return numpyro.infer.util.initialize_model(key, numpyro.handlers.seed(model, rng_seed=0))[0]

with jax.checking_leaks():
    jax.jit(run)(jax.random.PRNGKey(0))
|
@fehiepsi With that example, I was able to narrow the source of the bug further - thanks! The while loop of
import numpyro
import numpyro.distributions as dist
import jax

def model():
    return numpyro.sample('a', dist.Normal(0, 1))

def run(key):
    seeded = numpyro.handlers.seed(model, rng_seed=0)

    def cond_fn(state):
        i, num = state
        return i < 10

    def body_fn(state):
        i, num = state
        numpyro.handlers.trace(seeded).get_trace()  # this references the global rng values in a jitted context
        # equivalently, num = numpyro.handlers.trace(seeded).get_trace()['a']['value'] will raise an error
        return (i + 1, num)

    return jax.lax.while_loop(cond_fn, body_fn, (0, 0))

with jax.checking_leaks():
    jax.jit(run)(jax.random.PRNGKey(0))
You can also verify this by replacing
I think we figured it out, thanks for the examples! seed(model) is an instance of a
This way each time we call the model, a new instance of the seed handler will be created. Could you check if it works for your use case? I'll think of a long-term solution (maybe improving the docstring for this). |
Yes, this fixed it! Not sure if there's any interest in adding this to NumPyro, but here's the pattern for batching SVI: https://gist.github.com/amifalk/eb377a243b046105dc00beda79441b22 |
I came across a very strange bug while trying to vmap the SVI class (in order to parallelize model training across multiple initializations and across different datasets of the same shape). A tracer error occurs, but only if the AutoGuide has a site blocked out and there is also a deterministic site in the model. I wonder if this is related to #1657?