Is there a current work around for multiple devices? #31

Open
jwtkeeble opened this issue Oct 28, 2024 · 3 comments

Comments

@jwtkeeble

Hi all,

Firstly, I just want to say that this package is great! I was wondering whether there is a current workaround for using @scan_tqdm within jax.experimental.shard_map. When I try, I get the following error:

ValueError: The following ordered effects are not supported for more than 1 device: [<jax._src.debugging.OrderedDebugEffect object at 0x76972641ced0>]

Is there a workaround to display a progress bar for each device, or is this out of scope for the time being? Thanks!

@zombie-einstein
Collaborator

zombie-einstein commented Oct 29, 2024

Do you have a simple code example? I don't think there is an immediate workaround (especially since we somewhat rely on having ordered debug events), but it would be interesting to see what is going on.
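
For context on the ordered debug events mentioned in this comment: `jax.debug.callback` takes an `ordered` flag, and the JAX debugging documentation notes that ordered callbacks are not supported under `jax.pmap`, which is consistent with the errors reported in this thread. The sketch below is not jax_tqdm's implementation. It is a minimal illustration, assuming an unordered callback is acceptable, of reporting per-device progress from inside a scan under `jax.pmap`. The names `report`, `seen`, and `run` are made up for the example.

```python
import threading

import jax
import jax.numpy as jnp
from jax import lax

n_steps = 5
n_dev = jax.local_device_count()  # 1 unless more devices are available

seen = []  # (device_index, step) pairs recorded on the host
lock = threading.Lock()  # callbacks from different devices may overlap


def report(device_idx, step):
    # Hypothetical host-side hook; a real progress bar would be updated here.
    with lock:
        seen.append((int(device_idx), int(step)))


def step(carry, i):
    # ordered=False avoids the ordered debug effect, at the cost of
    # losing any guarantee about the order in which callbacks arrive.
    jax.debug.callback(report, lax.axis_index("i"), i, ordered=False)
    carry = carry + 1
    return carry, carry


def run(init):
    return lax.scan(step, init, jnp.arange(n_steps))


last, history = jax.pmap(run, axis_name="i")(jnp.zeros(n_dev))
jax.effects_barrier()  # wait until all pending callbacks have run
```

Because the callbacks are unordered, the host may see steps out of sequence, so a progress bar built this way would have to tolerate that (for example by tracking the highest step seen per device). Whether the same approach carries over to `shard_map` is not verified here.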

@jwtkeeble
Author

I can work on a simple code example to reproduce this error, and I'll share it here once it's written.

@jwtkeeble
Author

So, I've managed to put together a code snippet that works on a single device but not on multiple devices when run through jax.experimental.shard_map. It doesn't give exactly the same error as stated above, but I do get the same ordering-related error.

Also, I'm relatively new to JAX, so it won't be the most proficient code, but it does reproduce the error.

The script to reproduce it is below:

import os
os.environ["XLA_FLAGS"] = (
    '--xla_force_host_platform_device_count=4 ' # simulate with 4 CPUs
)

import jax
import jax.numpy as jnp
from jax import lax
from functools import partial

from jax.sharding import PartitionSpec as P
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.experimental import mesh_utils

from jax_tqdm import scan_tqdm
  
n_devices = jax.device_count()
devices = mesh_utils.create_device_mesh((n_devices,))
mesh = Mesh(devices, axis_names=('i',))

map_axis_name = mesh.axis_names[0]
Pspec = P(map_axis_name)
Pnone = P()

n = 10_000

@scan_tqdm(n) # NOTE: Removing this decorator makes the `shard_map` case work
@jax.jit
def step(carry, x):
    carry = carry + 1
    return carry, carry

last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))
print('last_number: ',last_number)
print('all_numbers: ',all_numbers)

# NOTE: Works up to here (single jax.lax.scan is OK)

def scan_step(lower, upper):
  @jax.jit
  @partial(jax.pmap, in_axes=(0,0), out_axes=(0,0))
  def fun(lower, upper):
    return lax.scan(step, lower, upper)
  return fun(lower, upper)

shard_step = shard_map(scan_step,
                       mesh=mesh,
                       in_specs=(Pspec, Pspec),
                       out_specs=(Pspec, Pspec),
                       check_rep=False)

shard_n = jnp.stack([jnp.arange(n) for _ in range(n_devices)], axis=0) # .reshape(4,-1)
zeros = jnp.zeros(n_devices)  # one initial carry per device

shard_last_number, shard_all_numbers = shard_step(zeros, shard_n) # NOTE: Crashes here
print(shard_last_number, shard_all_numbers)

With the @scan_tqdm decorator, this outputs:

Running for 10,000 iterations: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:00<00:00, 8858086.59it/s]
last_number:  10000
all_numbers:  [    1     2     3 ...  9998  9999 10000]
Traceback (most recent call last):
  File "~/shard_scan_example.py", line 57, in <module>
    shard_last_number, shard_all_numbers = shard_step(zeros, shard_n) # NOTE: Crashes here
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/shard_scan_example.py", line 46, in scan_step
    return fun(lower, upper)
           ^^^^^^^^^^^^^^^^^
ValueError: Ordered effects not supported for map primitives: [<jax._src.debugging.OrderedDebugEffect object at 0x7007736f3b50>]
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

and without the @scan_tqdm decorator:

last_number:  10000
all_numbers:  [    1     2     3 ...  9998  9999 10000]
[10000. 10000. 10000. 10000.] [[1.000e+00 2.000e+00 3.000e+00 ... 9.998e+03 9.999e+03 1.000e+04]
 [1.000e+00 2.000e+00 3.000e+00 ... 9.998e+03 9.999e+03 1.000e+04]
 [1.000e+00 2.000e+00 3.000e+00 ... 9.998e+03 9.999e+03 1.000e+04]
 [1.000e+00 2.000e+00 3.000e+00 ... 9.998e+03 9.999e+03 1.000e+04]]
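
Since the script runs once the decorator is removed, one stopgap is to apply the progress-bar decorator only on the single-device path. The following is a minimal sketch of that idea, not part of jax_tqdm: `maybe_progress` and `make_step` are hypothetical helpers, and `fake_progress` is a stand-in for `scan_tqdm(n)` so that the example runs without JAX or jax_tqdm installed.

```python
from functools import wraps


def maybe_progress(decorator, enabled):
    """Return `decorator` when enabled, otherwise an identity decorator."""
    return decorator if enabled else (lambda fn: fn)


calls = []  # records every x seen by the stand-in progress decorator


def fake_progress(fn):
    # Stand-in for scan_tqdm(n): wraps the step function and records calls.
    @wraps(fn)
    def wrapped(carry, x):
        calls.append(x)
        return fn(carry, x)
    return wrapped


def make_step(show_progress):
    # Build the scan step with or without the progress wrapper.
    @maybe_progress(fake_progress, show_progress)
    def step(carry, x):
        carry = carry + 1
        return carry, carry
    return step


single_device_step = make_step(show_progress=True)   # progress reported
multi_device_step = make_step(show_progress=False)   # plain step function
```

In the script above, the plain variant would be the one passed to `lax.scan` inside the mapped function, at the cost of having no progress bar there.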
