
Permutation is slow #5328

Closed
insperatum opened this issue Jan 6, 2021 · 6 comments

Hi awesome JAX people,

The jax.random.permutation function seems to be a couple of orders of magnitude slower than the PyTorch equivalent. (This is the cause of an issue downstream in NumPyro; see the discussion with @fehiepsi.)

PyTorch profiling:

%%timeit
import torch
x = torch.randperm(10000)[:500]

# 79.4 µs ± 418 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

JAX profiling:

import jax

def f(i):
    return jax.random.permutation(jax.random.PRNGKey(i), 10000)[:500]

jf = jax.jit(f)
%timeit x = jf(0).copy()

# 5.38 ms ± 300 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

Happy new year!
~ Luke

jakevdp (Collaborator) commented Jan 6, 2021

Torch and JAX use different approaches to random permutation.

Torch uses a Fisher-Yates algorithm on CPU (see https://github.com/pytorch/pytorch/blob/4883d39c6fd38431bdc60e1db6402e251429b1e1/aten/src/ATen/native/TensorFactories.cpp#L674-L696)

JAX avoids Fisher-Yates because it is costly to compute on immutable arrays and performs poorly on parallel architectures. JAX instead uses multiple sorts; see the source for details:
https://github.com/google/jax/blob/81990b40476087d5c3c2aab8d60aa6b921bb0b61/jax/_src/random.py#L510-L525
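A minimal sketch of the sort-based idea (illustrative only, not JAX's exact implementation; the function name and round count here are my own):

import jax
import jax.numpy as jnp

def sort_based_permutation(key, n, rounds=3):
    # Sort the indices by freshly drawn random keys, and repeat with
    # new keys each round so that residual bias from key collisions shrinks.
    perm = jnp.arange(n)
    for subkey in jax.random.split(key, rounds):
        sort_keys = jax.random.randint(
            subkey, (n,), 0, jnp.iinfo(jnp.int32).max, dtype=jnp.int32)
        _, perm = jax.lax.sort_key_val(sort_keys, perm)
    return perm

print(sort_based_permutation(jax.random.PRNGKey(0), 10))

Each round is a parallel-friendly sort rather than a serial sequence of in-place swaps, which is what makes this a better fit for accelerators and immutable arrays.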

Trying this on GPU for a large array shows that JAX's approach is slightly faster than torch's in this context:

import torch
%timeit torch.randperm(50000)[:500]
# 1000 loops, best of 3: 803 µs per loop

import jax

def f(i):
    return jax.random.permutation(jax.random.PRNGKey(i), 50000)[:500]

jf = jax.jit(f)
jf(0).block_until_ready()
%timeit x = jf(0).block_until_ready()
# 1000 loops, best of 3: 657 µs per loop

In short, this algorithm has been chosen and tuned to be fast in particular domains, and as a result it may be sub-optimal in other domains. What do you think?

fehiepsi (Contributor) commented Jan 6, 2021

@jakevdp How about switching between those algorithms based on the current backend platform? I did some tests and found that Fisher-Yates is pretty fast on CPU and pretty slow on GPU.
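For reference, a minimal NumPy sketch of the Fisher-Yates shuffle under discussion (illustrative only): it is a single serial pass of in-place swaps, which is cheap on CPU but maps poorly to GPUs and to immutable arrays.

import numpy as np

def fisher_yates(n, rng):
    # Classic Fisher-Yates: one backward pass, swapping each element
    # with a uniformly chosen earlier (or same) position, in place.
    perm = np.arange(n)
    for i in range(n - 1, 0, -1):
        j = rng.integers(0, i + 1)  # uniform over [0, i]
        perm[i], perm[j] = perm[j], perm[i]
    return perm

print(fisher_yates(10, np.random.default_rng(0)))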

jakevdp (Collaborator) commented Jan 6, 2021

> How about switching between those algorithms based on the current backend platform?

One issue there is that it would materially change the results of the operation when you change backends. Unless I'm mistaken, JAX attempts to keep the results of executing code consistent across backends (modulo unavoidable things like floating-point precision).

fehiepsi (Contributor) commented Jan 7, 2021

Understood! That makes sense to me.

jakevdp (Collaborator) commented Jan 7, 2021

I was curious about the relative speed of PyTorch and JAX on GPU; it turns out that PyTorch uses a single sort to permute arrays with more than 30000 elements on GPU: https://github.com/pytorch/pytorch/blob/d5a971e193c1f8ab83861c3ea258ddeb57d89f0c/aten/src/ATen/native/cuda/TensorFactories.cu#L79-L135

This is faster than JAX's multiple sorts, but due to the potential for key collisions it will produce biased permutations. Assuming 32-bit keys, a length-30000 array has approximately a 10% probability of a key collision.
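That figure follows from the standard birthday-problem approximation; a quick check (my own arithmetic, not from the linked code):

import math

# Probability that n draws from 2**32 equally likely 32-bit keys
# contain at least one collision: 1 - exp(-n*(n-1) / (2 * 2**32)).
n = 30000
p = 1 - math.exp(-n * (n - 1) / (2 * 2**32))
print(f"P(collision) = {p:.3f}")  # prints 0.099, i.e. roughly 10%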

jakevdp (Collaborator) commented Feb 27, 2021

I'm going to close this, because given the algorithmic constraints and JAX's goal of producing the same results on all backends, I don't think there's any improvement to be had.

jakevdp closed this as completed Feb 27, 2021