
Make subsample faster in CPU #865

Merged
merged 6 commits into from
Jan 17, 2021
Conversation

fehiepsi
Member

@fehiepsi fehiepsi commented Jan 6, 2021

This PR is motivated by a forum topic from @insperatum and Jake's answer there on the performance differences between the PyTorch (CPU) and JAX implementations.

Benchmarks

import numpyro
from jax import random, jit

size = 100000
subsample_size = 1000

@jit
def subsample_fn(rng_key):
    return numpyro.primitives._subsample_fn(size, subsample_size, rng_key)

key0, key1 = random.PRNGKey(0), random.PRNGKey(1)
x = subsample_fn(key0).copy()
%time x = subsample_fn(key1).copy()

returns

CPU times: user 262 µs, sys: 52 µs, total: 314 µs
Wall time: 253 µs

while in PyTorch, %time y = torch.randperm(size)[:subsample_size] took

CPU times: user 26.5 ms, sys: 1.37 ms, total: 27.9 ms
Wall time: 2.51 ms

and the previous implementation took

CPU times: user 60.2 ms, sys: 269 µs, total: 60.4 ms
Wall time: 54.1 ms

The reason for the high performance compared to PyTorch is that PyTorch takes a permutation of the full size first and then collects a subset. Here, we only take a partial permutation of size subsample_size. This was observed by @fritzo at this discussion.
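The idea behind the speedup can be sketched with a partial Fisher-Yates shuffle in plain NumPy. This is an illustration of the technique only, not the actual numpyro/JAX implementation, and `partial_shuffle_subsample` is a hypothetical name:

```python
import numpy as np

def partial_shuffle_subsample(size, subsample_size, rng):
    # Perform only the first `subsample_size` swap steps of a full
    # Fisher-Yates shuffle, so the cost is O(subsample_size) rather
    # than the O(size) cost of permuting everything and slicing.
    idx = np.arange(size)
    for i in range(subsample_size):
        j = rng.integers(i, size)  # random index in [i, size)
        idx[i], idx[j] = idx[j], idx[i]
    return idx[:subsample_size]

rng = np.random.default_rng(0)
sample = partial_shuffle_subsample(100_000, 1_000, rng)
```

After `subsample_size` swaps, the prefix `idx[:subsample_size]` is distributed exactly like the first `subsample_size` entries of a full random permutation, which is why slicing it is equivalent to `randperm(size)[:subsample_size]`.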

fritzo
fritzo previously approved these changes Jan 7, 2021
Member

@fritzo fritzo left a comment


Nice work! We'll have to port this to Pyro if anyone complains about slow subsampling speed there.

@fritzo
Member

fritzo commented Jan 7, 2021

Do you need to ensure double precision so that random.uniform can sample from sets larger than 2**24 ~ 16 million?
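The concern can be demonstrated directly: in float32, consecutive integers above 2**24 are no longer distinct, so a float32 uniform sample mapped onto a larger index set cannot reach every index. An illustrative NumPy check (not code from the PR):

```python
import numpy as np

n = 2 ** 24  # ~16 million, the last power of two with exact float32 integers

# Above 2**24, float32 has fewer mantissa bits than needed, and
# consecutive integers round to the same value.
assert np.float32(n) == np.float32(n + 1)   # consecutive ints collide
assert np.float32(n - 1) != np.float32(n)   # below 2**24 they stay distinct
```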

@fehiepsi
Copy link
Member Author

fehiepsi commented Jan 7, 2021

Thanks for reviewing and pointing that issue out, @fritzo! I have switched to randint instead of using uniform. It is a bit slower (451 µs ± 2.06 µs per loop, mean ± std. dev. of 7 runs, 1000 loops each) but is still pretty fast and does not suffer from the precision issue. In JAX, precision is decided only once, at the beginning of a program, so we can't switch between the two modes inside the implementation.
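The trade-off can be sketched in NumPy (illustrative comparison only, not the PR's code): drawing an index by scaling a float32 uniform sample loses exactness once the range exceeds 2**24, while an integer draw stays exact over the full range:

```python
import numpy as np

n = 100_000_000  # larger than 2**24, where float32 loses integer precision
rng = np.random.default_rng(0)

# uniform-based: a float32 sample has limited mantissa resolution, so
# scaling it to [0, n) cannot produce every index when n > 2**24
u = rng.random(dtype=np.float32)
idx_uniform = int(u * n)

# randint-based: exact over the full range, at a small speed cost
idx_randint = int(rng.integers(0, n))
```

Both draws land in [0, n), but only the integer-based draw can, in principle, return every index with the correct probability.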

@fehiepsi fehiepsi modified the milestones: 0.5.1, 0.5 Jan 16, 2021
Member

@fritzo fritzo left a comment


Sorry I lost track of this!

@fritzo fritzo merged commit 4f0f499 into pyro-ppl:master Jan 17, 2021