tracer error in blocked AutoGuide #1753

Closed
amifalk opened this issue Feb 29, 2024 · 12 comments · Fixed by #1759
Labels: bug (Something isn't working)

Comments

@amifalk
Contributor

amifalk commented Feb 29, 2024

I came across a very strange bug while trying to vmap the SVI class (in order to parallelize model training across multiple initializations + across different datasets of the same shape).

A tracer error occurs, but only if the AutoGuide has a site blocked out and there is also a deterministic site in the model. I wonder if this is related to #1657 ?

import jax
import jax.random as random

import numpyro
import numpyro.distributions as dist
from numpyro.handlers import block, seed
from numpyro.infer import SVI, TraceEnum_ELBO
from numpyro.infer.autoguide import AutoDelta

def model():
    a = numpyro.sample('a', dist.Normal(0, 1))
    b = numpyro.sample('b', dist.Normal(0, 1))

# -- this works --
keys = random.split(random.PRNGKey(0), 2)

optimizer = numpyro.optim.Adam(step_size=.01)
guide = AutoDelta(block(seed(model, rng_seed=0), hide=['b']))
svi = SVI(model, guide, optimizer, loss=TraceEnum_ELBO())

mapped_state = jax.vmap(svi.init)(keys)

def model_w_deterministic():
    a = numpyro.sample('a', dist.Normal(0, 1))
    b = numpyro.sample('b', dist.Normal(0, 1))
    
    numpyro.deterministic('test', a)

# -- this fails --
optimizer = numpyro.optim.Adam(step_size=.01)
guide = AutoDelta(block(seed(model_w_deterministic, rng_seed=0), hide=['b']))
svi = SVI(model_w_deterministic, guide, optimizer, loss=TraceEnum_ELBO())

mapped_state = jax.vmap(svi.init)(keys)
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was body_fn at /home/ami/venvs/jax/lib/python3.11/site-packages/numpyro/infer/util.py:358 traced for while_loop.
------------------------------
The leaked intermediate value was created on line /home/ami/venvs/jax/lib/python3.11/site-packages/numpyro/handlers.py:745:8 (seed.process_message).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/home/ami/venvs/jax/lib/python3.11/site-packages/numpyro/primitives.py:105:19 (Messenger.__call__)
<ipython-input-9-38f6b9474998>:17:8 (model_w_deterministic)
/home/ami/venvs/jax/lib/python3.11/site-packages/numpyro/primitives.py:222:10 (sample)
/home/ami/venvs/jax/lib/python3.11/site-packages/numpyro/primitives.py:47:8 (apply_stack)
/home/ami/venvs/jax/lib/python3.11/site-packages/numpyro/handlers.py:745:8 (seed.process_message)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
@amifalk changed the title from "memory leak in blocked AutoGuide" to "tracer error in blocked AutoGuide" on Feb 29, 2024
@fehiepsi
Member

fehiepsi commented Mar 1, 2024

Hi @amifalk, those autoguides are not designed to be composed with vmap after construction, because they need initialization (to inspect the model and generate things like prototype_trace). Something like this will work:

def init(...):
    guide = AutoDelta(...)
    return guide.init(...)

init_state = jax.vmap(init)(...)

@fehiepsi added the question label on Mar 1, 2024
@amifalk
Contributor Author

amifalk commented Mar 1, 2024

I think there's still an issue here. When svi.init is called, it initializes both the model and the guide for the first time, which should set up the prototype trace. Batched model fitting works with AutoGuides if I vmap the SVI methods in all cases except when there is both a deterministic site and the guide is based on a blocked model.

The suggested approach yields the same error as before (though AutoGuides are not registered as pytrees so they cannot be returned after calling vmap).

def guide_init(rng_seed):
   guide = AutoDelta(block(seed(model, rng_seed=rng_seed), hide=['b']))
   seed(guide, rng_seed=rng_seed)()
   
   return 

keys = random.split(random.PRNGKey(0))
jax.vmap(guide_init)(keys) # this works

def guide_init_deterministic(rng_seed):
   guide = AutoDelta(block(seed(model_w_deterministic, rng_seed=rng_seed), hide=['b']))
   seed(guide, rng_seed=rng_seed)()
   
   return 

keys = random.split(random.PRNGKey(0))
jax.vmap(guide_init_deterministic)(keys) # tracer error

@amifalk
Contributor Author

amifalk commented Mar 2, 2024

@fehiepsi I've traced the source to the while loop in find_valid_initial_params (numpyro/infer/util.py). If I set _DISABLE_CONTROL_FLOW_PRIM = True, vmapping the svi.init method works. However, vmapping the guide initialization then yields a new error in the while loop:

This BatchTracer with object id 140305711093520 was created on line:
  /home/ami/venvs/jax/lib/python3.11/site-packages/numpyro/infer/util.py:357:15 (find_valid_initial_params.<locals>.cond_fn)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

@fehiepsi
Member

fehiepsi commented Mar 2, 2024

If we use a Python while loop, then the condition needs to be a Python value like True or False. Having a JAX object there won't work. What is your use case, by the way?
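A minimal sketch of the point above, assuming a toy halving loop rather than NumPyro's actual initialization code. The function names halve_python and halve_lax are hypothetical. Under vmap, a Python while loop tries to convert the traced condition to a bool and fails, while jax.lax.while_loop keeps the condition as a JAX value.

```python
import jax
import jax.numpy as jnp


def halve_python(x):
    # A Python `while` calls bool() on the condition, which fails on a traced value.
    while x > 1.0:
        x = x / 2.0
    return x


def halve_lax(x):
    # lax.while_loop keeps the condition as a traced JAX value.
    return jax.lax.while_loop(lambda v: v > 1.0, lambda v: v / 2.0, x)


xs = jnp.array([3.0, 10.0])

try:
    jax.vmap(halve_python)(xs)
    python_loop_ok = True
except jax.errors.TracerBoolConversionError:
    python_loop_ok = False

# The batched loop runs until every element satisfies the exit condition.
batched = jax.vmap(halve_lax)(xs)
```

This is the same class of error as the TracerBoolConversionError reported above when the control-flow primitive is disabled and the loop condition depends on a batched value.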

@amifalk
Contributor Author

amifalk commented Mar 2, 2024

I have a blocked model with a deterministic site that I'm trying to perform some simulation studies on. I want to see how variations in the structure of the dataset / model hyperparameters affect the performance, and I also want to be able to select the best result over multiple initializations. It's very slow to do this sequentially (for a small grid of hyperparams it took around 40 minutes), but after vmapping/pmapping with GPU I can get the entire grid to run in parallel. In my case it reduced the fitting time to 7 seconds.

Unfortunately, if I try to vmap the blocked model with deterministic sites present, it throws this error, so I have to instead recompute the deterministic sites at the end of model fitting.

In my case, I need to block the model to define an AutoGuide that is compatible with enumeration (blocking out the enumerated sites), but this would likely also be a problem for people using AutoGuideList.

@fehiepsi
Member

fehiepsi commented Mar 3, 2024

I think you can do something like

def run_svi(...):
    svi = ...
    svi_result = svi.run(...)
    return svi_result

svi_results = vmap(run_svi)(...)

@amifalk
Contributor Author

amifalk commented Mar 3, 2024

Unfortunately this still seems to throw the same tracer error.

from numpyro.infer import Trace_ELBO

def run_svi(key):
    optimizer = numpyro.optim.Adam(step_size=.01)
    guide = AutoDelta(block(seed(model, rng_seed=0), hide=['b']))
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

    return svi.run(key, 100, progress_bar=False)

def run_svi_deterministic(key):
    optimizer = numpyro.optim.Adam(step_size=.01)
    guide = AutoDelta(block(seed(model_w_deterministic, rng_seed=0), hide=['b']))
    svi = SVI(model_w_deterministic, guide, optimizer, loss=Trace_ELBO())

    return svi.run(key, 100, progress_bar=False)

keys = random.split(random.PRNGKey(0), 2)

jax.vmap(run_svi)(keys) # works
jax.vmap(run_svi_deterministic)(keys) # tracer error from the while loop in find_valid_initial_params

@amifalk
Contributor Author

amifalk commented Mar 3, 2024

Can we make it so that AutoGuides only collect non-enumerated model sample sites? This wouldn't fix the problem for all blocked models, but it would make collecting deterministic sites possible under batched svi for my use-case.

I think this would only have to be a one-liner change here-ish where we just ignore sample sites in the prototype trace that have site['infer'].get('enumerate') == True. That would also make the syntax for defining AutoGuides for enumerated models much simpler, e.g. just AutoGuide(model), instead of
AutoGuide(block(seed(model, rng_seed=0), hide=["enumerated_site_1", "enumerated_site_2", ...]))
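A rough sketch of the filtering idea proposed above, written against a hand-built dictionary rather than NumPyro itself. The helper name non_enumerated_latent_sites and the example entries are hypothetical. The keys (type, is_observed, infer) are assumed to mirror what a NumPyro trace records, and the check treats any truthy enumerate value as enumerated, on the assumption that such sites are marked with something like infer={"enumerate": "parallel"}.

```python
def non_enumerated_latent_sites(prototype_trace):
    """Return names of latent sample sites that are not marked for enumeration."""
    names = []
    for name, site in prototype_trace.items():
        if site["type"] != "sample" or site.get("is_observed", False):
            continue  # skip deterministic, param, and observed sites
        if site.get("infer", {}).get("enumerate"):
            continue  # skip sites marked for enumeration
        names.append(name)
    return names


# Hand-written stand-in for a prototype trace (assumed layout, not real output).
example_trace = {
    "a": {"type": "sample", "is_observed": False, "infer": {}},
    "z": {"type": "sample", "is_observed": False, "infer": {"enumerate": "parallel"}},
    "y": {"type": "sample", "is_observed": True, "infer": {}},
    "test": {"type": "deterministic"},
}

guide_sites = non_enumerated_latent_sites(example_trace)
```

With a filter like this, only site a would be handed to the guide, so the enumerated site z would not need to be hidden by hand.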

@fehiepsi
Member

Thanks @amifalk! There is indeed leakage here with the seed handler. I haven't been able to figure out why yet. Posting here for reference

import numpyro
import numpyro.distributions as dist
import jax

def model():
    return numpyro.sample('a', dist.Normal(0, 1))    

def run(key):
    return numpyro.infer.util.initialize_model(key, numpyro.handlers.seed(model, rng_seed=0))[0]

with jax.checking_leaks():
    jax.jit(run)(jax.random.PRNGKey(0))

@fehiepsi added the bug label and removed the question label on Mar 12, 2024
@amifalk
Contributor Author

amifalk commented Mar 12, 2024

@fehiepsi With that example, I was able to narrow down the source of the bug further. Thanks! The while loop in find_valid_initial_params closes over the seeded model, but it also traces the model during its calls to potential_fn. I think the error is caused by the trace seeing an RNG key from the global call to seed. Here's a minimal example.

import numpyro
import numpyro.distributions as dist
import jax

def model():
    return numpyro.sample('a', dist.Normal(0, 1))    

def run(key):
    seeded = numpyro.handlers.seed(model, rng_seed=0)

    def cond_fn(state):
        i, num = state
        return i < 10 

    def body_fn(state):
        i, num = state
        
        numpyro.handlers.trace(seeded).get_trace() # this references the global rng values in a jitted context
        # equivalently num = numpyro.handlers.trace(seeded).get_trace()['a']['value'] will raise an error        
        return (i + 1, num)

    return jax.lax.while_loop(cond_fn, body_fn, (0, 0))
    
with jax.checking_leaks():
    jax.jit(run)(jax.random.PRNGKey(0))

You can also verify this by replacing potential_fn in numpyro.infer.util.find_valid_initial_params with a placeholder that just returns a constant number.

@fehiepsi
Member

fehiepsi commented Mar 12, 2024

I think we figured it out. Thanks for the examples!

seed(model) is an instance of the seed class, which has mutable state. One fix is to wrap the seeded model in a function, like

def seeded_model(*args, **kwargs):
    return seed(model, rng_seed=random.PRNGKey(0))(*args, **kwargs)

This way, each time we call the model, a new instance of the seed handler is created. Could you check whether it works for your use case? I'll think about a long-term solution (maybe improving the docstring for this).
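To make the difference concrete, here is a toy illustration in plain Python. ToySeed is a hypothetical stand-in, not NumPyro's seed handler: it only mimics the idea of a handler that stores a key-like value and advances it on every draw. A shared instance carries state from one call to the next, while a wrapper that builds a fresh instance per call always starts from the same state.

```python
class ToySeed:
    """Toy stand-in for a stateful seed handler (not NumPyro's implementation)."""

    def __init__(self, fn, rng_seed):
        self.fn = fn
        self.state = rng_seed  # mutable state, advanced on every draw

    def draw(self):
        # Advance the stored state, loosely mimicking a key split.
        self.state = (self.state * 1103515245 + 12345) % 2**31
        return self.state

    def __call__(self, *args, **kwargs):
        return self.fn(self.draw, *args, **kwargs)


def model(draw):
    return draw()


# One shared instance: each call mutates the handler's stored state.
shared = ToySeed(model, rng_seed=0)
first, second = shared(), shared()


# A wrapper that builds a fresh handler on every call, as in the suggestion above.
def seeded_model(*args, **kwargs):
    return ToySeed(model, rng_seed=0)(*args, **kwargs)


fresh_first, fresh_second = seeded_model(), seeded_model()
```

In the real setting, the stored state would be a JAX value. If a shared instance is called inside a traced function, the updated value it keeps is a tracer that outlives the trace, which is consistent with the leak reported in this thread.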

@amifalk
Contributor Author

amifalk commented Mar 13, 2024

Yes, this fixed it! Not sure if there's any interest in adding to NumPyro, but here's the pattern for batching SVI: https://gist.github.com/amifalk/eb377a243b046105dc00beda79441b22
