XlaRuntimeError: Outside call <jax.experimental.host_callback._CallbackWrapper object ... > #12
Thanks for raising this! The code assumes that iterations start at 0 (which is when the progress bar is first defined), and that is why it fails here. The alternative (but undocumented!) way of using the progress bar with scans, which allows scanning over any iterable, is to pass in a tuple whose first element is a zero-based iteration counter and whose second element is the data you actually want to scan over. In the example you gave, this corresponds to writing:

```python
last_number, all_numbers = lax.scan(step, 0, (jnp.arange(0, n_stop - n_start), jnp.arange(n_start, n_stop)))
```

I realise now that this is perhaps a bit convoluted and that it could rather be done behind the scenes, so that the user doesn't need to do anything different. Thanks again for flagging this; I'll look into a fix soon! Let me know if passing in a tuple doesn't work for you. If it does, we could document this feature in the README.md until we have a nicer fix.
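To see what the step function receives in this case, here is a minimal plain-JAX sketch with the progress bar left out (the bounds `n_start = 5`, `n_stop = 10` and the helper `show_x` are hypothetical). With a tuple as `xs`, `lax.scan` slices each element along its leading axis, so every step gets a pair whose first entry counts up from 0 regardless of where the data itself starts.

```python
import jax.numpy as jnp
from jax import lax

n_start, n_stop = 5, 10  # hypothetical bounds, as in the issue's example

# First element: zero-based iteration counter; second: the values to scan over.
xs = (jnp.arange(0, n_stop - n_start), jnp.arange(n_start, n_stop))

def show_x(carry, x):
    # Hypothetical helper: unpack the pair and emit both parts unchanged.
    i, value = x
    return carry, (i, value)

_, (counters, values) = lax.scan(show_x, 0, xs)
```

Here `counters` runs from 0 to 4 while `values` runs from 5 to 9, which is why a progress bar that keys off the first element still sees its first iteration at 0.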
Yes, the alternative work-around also works for me. Though in this example … Thanks!
Yes, that's right! So the step function would have to be modified (assuming here that you're using …):

```python
def step(carry, x):
    _, y = x  # drop the iteration counter, keep the actual value
    return carry + 1 + y, carry + 1 + y
```
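A quick runnable check of this modified step function, in plain JAX with the progress-bar decorator left out (the bounds `n_start = 5`, `n_stop = 8` and the comparison function `plain_step` are hypothetical): unpacking the tuple gives the same results as scanning directly over `jnp.arange(n_start, n_stop)`.

```python
import jax.numpy as jnp
from jax import lax

n_start, n_stop = 5, 8  # hypothetical bounds

def step(carry, x):
    # x is a (zero-based counter, value) pair; the counter only serves the progress bar
    _, y = x
    return carry + 1 + y, carry + 1 + y

def plain_step(carry, y):
    # Hypothetical reference: the same update, scanning directly over the values
    return carry + 1 + y, carry + 1 + y

xs = (jnp.arange(0, n_stop - n_start), jnp.arange(n_start, n_stop))
last_number, all_numbers = lax.scan(step, 0, xs)
ref_last, ref_all = lax.scan(plain_step, 0, jnp.arange(n_start, n_stop))
```

Both scans produce the same carry and outputs; the only difference is that the tuple form also exposes a counter that starts at 0.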
I think the following is related. On JAX 0.4.28 I get:
I've actually been meaning to file a similar issue, but I was trying to track down the root cause first. I've seen this error, but then everything runs fine when I run with a larger print rate (i.e. a longer period between updates of the progress bar). That made me suspect some async issue, but I still need to pin down the exact circumstances. This is with JAX 0.4.26, also with a progress bar that begins at 0.
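To make the print-rate idea concrete: inside a compiled scan, a progress bar can only reach the host through a callback, and a larger print rate simply means fewer of those callbacks. The sketch below is hypothetical and is not the library's implementation (the error in the title comes from `jax.experimental.host_callback`; `jax.debug.callback` is used here as a stand-in, and `print_rate`, `updates` and `_record_update` are illustrative names). It only calls back to the host on multiples of `print_rate`.

```python
import jax
import jax.numpy as jnp
from jax import lax

print_rate = 10  # hypothetical: report every 10 iterations
updates = []     # host-side record of which iterations triggered a callback

def _record_update(i):
    # Runs on the host; a real progress bar would update its display here.
    updates.append(int(i))

def step(carry, i):
    # Only call back to the host on multiples of print_rate.
    lax.cond(
        i % print_rate == 0,
        lambda: jax.debug.callback(_record_update, i),
        lambda: None,
    )
    return carry + 1, None

final, _ = lax.scan(step, 0, jnp.arange(50))
jax.effects_barrier()  # wait for any pending host callbacks to complete
```

With 50 iterations and `print_rate = 10`, only five host callbacks fire (at iterations 0, 10, 20, 30 and 40); unordered callbacks are not guaranteed to arrive in sequence, so any code that depends on their order should not assume it.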
I have found the exact same thing, but I could not consistently reproduce either the errors or the successful runs. For example, I use …
Ah, thanks for the info; I was having a similar thought! I wonder if there is some way to flush the update at the first step?
I wonder if the …
Good spot! I just added …
Thanks for releasing and maintaining this little library; I use it quite a lot in my daily JAX work. When I perform a decorated `lax.scan` over a range that doesn't start at zero, I get an `Outside call` error. See for instance this example:
Which raises:
If I replace `jnp.arange(n_start, n_stop)` with `jnp.arange(n_stop - n_start)` and add `n_start` at the end, I get the desired behaviour of `all_numbers` running from `n_start` up to (but excluding) `n_stop`, with a working `tqdm` progress bar. This is an easy work-around to implement, but I wonder why it doesn't work the more conventional way. Anyway, since there is an easy work-around, this should be a low-priority issue.
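The work-around described above, as a minimal plain-JAX sketch without the progress bar. The step function body and the bounds `n_start = 3`, `n_stop = 8` are assumptions, since the original example was not preserved: scan over a zero-based range so the first iteration number is 0, then shift the results by `n_start`.

```python
import jax.numpy as jnp
from jax import lax

n_start, n_stop = 3, 8  # hypothetical bounds

def step(carry, x):
    # Hypothetical step: the carry and the output are simply the current value.
    return x, x

# Scan over a zero-based range so that the first iteration number is 0 ...
last_offset, offsets = lax.scan(step, 0, jnp.arange(n_stop - n_start))
# ... then shift both results back to the intended range.
last_number, all_numbers = last_offset + n_start, offsets + n_start
```

Note that the offset has to be applied to every result that depends on the scanned values, here both the final carry and the stacked outputs.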