
XlaRuntimeError: Outside call <jax.experimental.host_callback._CallbackWrapper object ... > #12

Open
martijnende opened this issue Mar 8, 2023 · 9 comments


@martijnende

martijnende commented Mar 8, 2023

Thanks for releasing and maintaining this little library; I use it quite a lot in my daily JAX work. When I run lax.scan with a decorated step function over an array that doesn't start at zero, I get an "Outside call" error. See for instance this example:

from jax_tqdm import scan_tqdm
from jax import lax
import jax.numpy as jnp

n_start = 5_000
n_stop = 10_000

@scan_tqdm(n_stop - n_start)
def step(carry, x):
    return carry + 1, carry + 1

last_number, all_numbers = lax.scan(step, 0, jnp.arange(n_start, n_stop))

Which raises:

ERROR:absl:Outside call <jax.experimental.host_callback._CallbackWrapper object at 0x7f8ba81b6760> threw exception 0.
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Input In [5], in <cell line: 12>()
      8 @scan_tqdm(n_stop - n_start)
      9 def step(carry, x):
     10     return carry + 1, carry + 1
---> 12 last_number, all_numbers = lax.scan(step, 0, jnp.arange(n_start, n_stop))

    [... skipping hidden 7 frame]

File ~/miniconda3/envs/abyss/lib/python3.9/site-packages/jax/_src/dispatch.py:837, in _execute_compiled(name, compiled, input_handler, output_buffer_counts, result_handler, has_unordered_effects, ordered_effects, kept_var_idx, has_host_callbacks, *args)
    835     runtime_token = None
    836 else:
--> 837   out_flat = compiled.execute(in_flat)
    838 check_special(name, out_flat)
    839 out_bufs = unflatten(out_flat, output_buffer_counts)

XlaRuntimeError: INTERNAL: Generated function failed: KeyError: 0

If I replace jnp.arange(n_start, n_stop) with jnp.arange(n_stop - n_start) and add n_start to the result afterwards, I get the desired behaviour (all_numbers running from n_start to n_stop) with a working tqdm progress bar. This work-around is easy to implement, but I wonder why the more conventional form doesn't work.
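The work-around described above can be sketched as follows: scan over a zero-based index and shift the outputs by n_start afterwards. In real use the step would carry the @scan_tqdm(n_stop - n_start) decorator; it is left off here (an assumption made only so the snippet runs without jax_tqdm installed), and the step returns the index rather than the original counter to make the shift easy to see.

```python
import jax.numpy as jnp
from jax import lax

n_start = 5_000
n_stop = 10_000

# In practice: @scan_tqdm(n_stop - n_start)
def step(carry, i):
    # i runs over 0, 1, ..., n_stop - n_start - 1, which is what the
    # progress bar expects as an iteration counter.
    return carry + 1, i

count, indices = lax.scan(step, 0, jnp.arange(n_stop - n_start))
all_numbers = indices + n_start  # shift back to n_start, ..., n_stop - 1
```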

Anyway, since there is an easy work-around, this should be a low-priority issue.

@jeremiecoullon
Owner

jeremiecoullon commented Mar 8, 2023

Thanks for raising this!

The code assumes that iterations start at 0 (which is when the progress bar is created), which is why it fails here.
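A toy, pure-Python model of the failure mode explained above may help. It is hypothetical and far simpler than jax_tqdm's real code: a bar is created only when the scanned value equals 0 and is updated on every later step. A range starting at 5_000 therefore never creates the bar, and the first update fails with the same `KeyError: 0` as in the traceback. The names bars, on_step and run are made up for illustration.

```python
# Hypothetical, simplified model; not jax_tqdm's actual code.
bars = {}

def on_step(bar_id, iter_num):
    if iter_num == 0:
        bars[bar_id] = 0   # "create" the progress bar at the first iteration
    else:
        bars[bar_id] += 1  # "update" it; KeyError if it was never created

def run(xs, bar_id=0):
    bars.clear()
    for x in xs:
        on_step(bar_id, x)
    return bars[bar_id]
```

With this model, run(range(0, 5)) returns 4, while run(range(5_000, 10_000)) raises KeyError: 0 on its very first step.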

An alternative (but undocumented!) way of using the progress bar with scans, which allows scanning over any array, is to pass the tuple (jnp.arange(0, num_iters), my_range) as xs, where my_range is whatever you want to scan over.
lax.scan then scans over my_range, while the progress bar uses jnp.arange(0, num_iters).

In the example you gave this corresponds to writing:

last_number, all_numbers = lax.scan(step, 0, (jnp.arange(0, n_stop-n_start), jnp.arange(n_start, n_stop)))

I realise now that this is a bit convoluted, and that it could instead be done behind the scenes so that the user doesn't need to do anything differently.
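One way the "behind the scenes" idea could look is a small wrapper that builds the zero-based counter itself, so the user's step only ever sees the original x. This is a hypothetical helper (scan_with_index is not part of jax_tqdm), and the place where a progress-bar hook would read the counter is only marked with a comment.

```python
import jax
import jax.numpy as jnp
from jax import lax

def scan_with_index(step, init, xs):
    """Hypothetical helper: pairs each element of xs with a zero-based
    counter so a progress bar could key off the counter, while `step`
    keeps its usual (carry, x) signature."""
    length = jax.tree_util.tree_leaves(xs)[0].shape[0]

    def wrapped(carry, idx_and_x):
        idx, x = idx_and_x
        # A progress-bar hook would read `idx` here; it always starts at 0.
        return step(carry, x)

    return lax.scan(wrapped, init, (jnp.arange(length), xs))

# The user's step is unchanged and xs may start anywhere.
last, ys = scan_with_index(lambda c, x: (c + 1, x), 0, jnp.arange(5_000, 10_000))
```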

Thanks again for flagging this; I’ll look into doing this fix soon!

Let me know if this solution of passing in a tuple doesn't work for you. If it works, we could document this feature in the README.md until we do a nicer fix.

@martijnende
Author

Yes, the alternative work-around also works for me, though in this example x now becomes a tuple, effectively a zip of (jnp.arange(0, n_stop-n_start), jnp.arange(n_start, n_stop)).
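The "zip" observation can be checked directly: lax.scan slices every leaf of a tuple xs along its leading axis in lockstep, so each step receives one (index, value) pair. The bounds are shrunk to 5 and 8 here purely so the output is easy to read.

```python
import jax.numpy as jnp
from jax import lax

n_start, n_stop = 5, 8

def step(carry, x):
    # x is one slice of each tuple element, i.e. an (index, value) pair.
    idx, value = x
    return carry, (idx, value)

_, (idxs, values) = lax.scan(
    step, 0, (jnp.arange(0, n_stop - n_start), jnp.arange(n_start, n_stop))
)
```

Here idxs holds 0, 1, 2 and values holds 5, 6, 7.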

Thanks!

@jeremiecoullon
Owner


Yes, that's right! So the step function would have to be modified (assuming you're using x in step) to ignore the zeroth element of x:

def step(carry, x):
    _, y = x
    return carry + 1 + y, carry + 1 + y
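Putting the tuple call and the modified step together gives a complete sketch that runs with plain lax.scan. The @scan_tqdm(n_stop - n_start) decorator is left off only so the snippet runs without jax_tqdm installed, and the explicit int32 initial carry is an assumption made to keep the carry dtype fixed across iterations.

```python
import jax.numpy as jnp
from jax import lax

n_start = 5_000
n_stop = 10_000

# In practice: @scan_tqdm(n_stop - n_start)
def step(carry, x):
    _, y = x  # ignore the zero-based index; it only drives the progress bar
    return carry + 1 + y, carry + 1 + y

init = jnp.array(0, dtype=jnp.int32)
xs = (jnp.arange(0, n_stop - n_start), jnp.arange(n_start, n_stop))
last_number, all_numbers = lax.scan(step, init, xs)
```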

@mdmould
Contributor

mdmould commented May 14, 2024

I think the following is related. On jax 0.4.28 I get KeyError: 0 in _update_tqdm from tqdm_bars[0].update(arg), even when iterations begin at 0.

@zombie-einstein
Collaborator

zombie-einstein commented May 14, 2024


I've actually been meaning to file a similar issue, but I was trying to track down the root cause. I've seen this error, but then everything runs OK when I run with a larger print rate (i.e. a longer period between progress bar updates). That made me suspect an async issue, but I still need to pin down the exact circumstances.

This is with JAX 0.4.26. Also a progress bar that begins at 0.
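One way a print rate can throttle host traffic is sketched below. This is an assumption about the general pattern, not jax_tqdm's actual implementation: the step only issues a jax.debug.callback when the iteration index is a multiple of print_rate, gated with lax.cond. Fewer callbacks mean fewer chances for an update to race ahead, which would be consistent with the observation above, though the sketch does not prove it. The names calls, report and make_step are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

calls = []

def report(i):
    # Host-side stand-in for a progress-bar update.
    calls.append(int(i))

def make_step(print_rate):
    def step(carry, i):
        # Only emit a host callback on every `print_rate`-th iteration.
        lax.cond(
            i % print_rate == 0,
            lambda j: jax.debug.callback(report, j),
            lambda j: None,
            i,
        )
        return carry + 1, i
    return step

_ = lax.scan(make_step(3), 0, jnp.arange(10))
jax.effects_barrier()  # wait until all pending callbacks have run
```

With ten iterations and a print rate of 3, only iterations 0, 3, 6 and 9 reach the host (in no guaranteed order, since these callbacks are unordered).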

@mdmould
Copy link
Contributor

mdmould commented May 14, 2024


I have found exactly the same thing, but I could not consistently reproduce either the error or the successful runs. For example, I use scan_tqdm around steps in a training loop, and changing the number of steps or even the batch size triggers the error. So I also suspect an async issue, where progress bar updates are attempted before the progress bar has been created.
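The suspected race can be illustrated with a toy model that replays host-side events against a dict of bars. It is hypothetical and only models the hypothesis in this thread, not JAX's actual callback scheduling: delivered in program order, the "create" event always precedes the updates; if an update overtakes it, the lookup fails with the same KeyError: 0 reported from _update_tqdm.

```python
def replay(events):
    """Hypothetical model: apply 'create'/'update' host events to a dict of bars."""
    bars = {}
    for kind, bar_id in events:
        if kind == "create":
            bars[bar_id] = 0
        else:
            bars[bar_id] += 1  # raises KeyError if the bar was never created
    return bars

# Events in program order: the bar is created at step 0, then updated.
in_order = [("create", 0), ("update", 0), ("update", 0), ("update", 0)]

# The same events, except the first update overtook the create.
racy = [in_order[1], in_order[0], *in_order[2:]]
```

replay(in_order) returns {0: 3}, whereas replay(racy) raises KeyError: 0.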

@zombie-einstein
Collaborator


Ah, thanks for the info, I was having a similar thought! I wonder if there is some way to flush the update at the first step?

@zombie-einstein
Collaborator

I wonder if the ordered flag here is what we are looking for?

@mdmould
Copy link
Contributor

mdmould commented May 14, 2024


Good spot! I just added ordered=True to all calls of callback and my previous failure cases seem to be working.
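A minimal sketch of this fix, assuming the progress-bar updates are issued through jax.debug.callback: passing ordered=True asks JAX to execute the host callbacks in program order, so a callback issued at step 0 cannot be overtaken by later ones. Here record and seen are hypothetical stand-ins for the progress-bar update, and jax.effects_barrier() waits for outstanding callbacks before the results are inspected.

```python
import jax
import jax.numpy as jnp
from jax import lax

seen = []

def record(i):
    # Host-side stand-in for a progress-bar update; i arrives as an array scalar.
    seen.append(int(i))

def step(carry, i):
    # ordered=True: these host callbacks run in the order the program issues them.
    jax.debug.callback(record, i, ordered=True)
    return carry + i, i

@jax.jit
def run(xs):
    return lax.scan(step, jnp.array(0, dtype=jnp.int32), xs)

total, _ = run(jnp.arange(10, dtype=jnp.int32))
jax.effects_barrier()  # block until all pending callbacks have executed
```

After the barrier, seen lists the iterations 0 through 9 in order and total is 45.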
