Permutation is slow #5328
torch and jax use different approaches to random permutation. Torch uses a Fisher-Yates algorithm on CPU (see https://github.com/pytorch/pytorch/blob/4883d39c6fd38431bdc60e1db6402e251429b1e1/aten/src/ATen/native/TensorFactories.cpp#L674-L696). JAX avoids Fisher-Yates because it is costly to compute on immutable arrays and performs poorly on parallel architectures; JAX instead uses multiple sorts (see the source for details). Trying this on GPU for a large array shows that JAX's approach is slightly faster than torch's in this context:

import torch
%timeit torch.randperm(50000)[:500]
# 1000 loops, best of 3: 803 µs per loop
import jax
def f(i):
    return jax.random.permutation(jax.random.PRNGKey(i), 50000)[:500]
jf = jax.jit(f)
jf(0).block_until_ready()
%timeit x = jf(0).block_until_ready()
# 1000 loops, best of 3: 657 µs per loop

In short, this algorithm has been chosen and tuned to be fast in particular domains, and as a result it may be sub-optimal in other domains. What do you think?
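For reference, the core of the sort-based idea can be sketched as follows. This is only a minimal illustration (the helper name is made up); the actual jax.random.permutation implementation applies several rounds of sorting on random keys rather than a single argsort:

import jax
import jax.numpy as jnp

def sort_based_permutation(key, n):
    # Hypothetical one-round sketch: draw one random sort key per element
    # and argsort them. A Fisher-Yates loop would need n sequential
    # in-place swaps, which maps poorly onto immutable arrays and
    # parallel hardware; a sort is a single data-parallel primitive.
    sort_keys = jax.random.uniform(key, (n,))
    return jnp.argsort(sort_keys)

perm = sort_based_permutation(jax.random.PRNGKey(0), 50000)[:500]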
@jakevdp How about switching between those algorithms based on the current backend platform? I did some tests and found that Fisher-Yates is pretty fast on CPU and pretty slow on GPU.
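For concreteness, that proposal might look roughly like the sketch below. It is hypothetical (permutation_dispatch and fisher_yates_permutation are made-up names, and as the next comment explains, this approach was not adopted):

import jax
import numpy as np

def fisher_yates_permutation(seed, n):
    # Plain in-place Fisher-Yates on the host: fast on CPU, but the
    # loop is inherently sequential, so it is a poor fit for GPU/TPU.
    rng = np.random.default_rng(seed)
    out = np.arange(n)
    for i in range(n - 1, 0, -1):
        j = rng.integers(0, i + 1)
        out[i], out[j] = out[j], out[i]
    return out

def permutation_dispatch(key, n, seed=0):
    # Hypothetical dispatch on the active backend; the same key would
    # then yield different permutations on CPU vs. GPU.
    if jax.default_backend() == "cpu":
        return fisher_yates_permutation(seed, n)
    return jax.random.permutation(key, n)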
One issue there is that it would materially change the results of the operation when you change backends. Unless I'm mistaken, I think JAX attempts to keep the result of executing code consistent across backends (modulo unavoidable things like floating point precision).
Understood! That makes sense to me.
I was curious about the relative speed of pytorch and JAX on GPU; it turns out that pytorch uses a single sort to permute arrays with more than 30000 elements on GPU: https://github.com/pytorch/pytorch/blob/d5a971e193c1f8ab83861c3ea258ddeb57d89f0c/aten/src/ATen/native/cuda/TensorFactories.cu#L79-L135

This is faster than JAX's multiple sorts, but due to the potential for key collisions, it will produce biased permutations. Assuming 32-bit keys, a length-30000 array has approximately a 10% probability of a key collision.
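That 10% figure is just the birthday bound; a quick check (assuming 30000 keys drawn uniformly and independently from 2^32 possible values):

import math

n = 30000
key_space = 2 ** 32
# Birthday bound: P(at least one collision) ≈ 1 - exp(-n*(n-1) / (2 * key_space))
p_collision = 1 - math.exp(-n * (n - 1) / (2 * key_space))
print(f"{p_collision:.3f}")  # ≈ 0.099, i.e. roughly 10%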
I'm going to close this, because given the algorithmic constraints and JAX's goal of producing the same results on all backends, I don't think there's any improvement to be had. |
Hi awesome Jax people,
The jax.random.permutation function seems to be a couple of orders of magnitude slower than the pytorch equivalent. (This is the cause of an issue downstream in numpyro; see the discussion with @fehiepsi.)

Pytorch profiling
Jax profiling
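A minimal CPU comparison along the following lines illustrates the kind of gap being reported; exact numbers depend on the machine, and the array size of 50000 is just an example:

import timeit
import jax
import torch

n = 50000

def torch_perm():
    return torch.randperm(n)[:500]

jax_perm = jax.jit(lambda key: jax.random.permutation(key, n)[:500])
jax_perm(jax.random.PRNGKey(0)).block_until_ready()  # compile once

print("torch:", timeit.timeit(torch_perm, number=100))
print("jax:  ", timeit.timeit(
    lambda: jax_perm(jax.random.PRNGKey(0)).block_until_ready(), number=100))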
Happy new year!
~ Luke