Make subsample faster in CPU #865
Conversation
Nice work! We'll have to port this to Pyro if anyone complains about slow subsampling speed there.
Do you need to ensure double precision so that …
Thanks for reviewing and pointing out that issue, @fritzo! I have switched to …
Sorry I lost track of this!
This PR was motivated by @insperatum's forum topic and Jake's answer there on the differences between the PyTorch CPU and JAX implementations.
Benchmarks
The new implementation returns quickly, while in PyTorch

```python
%time y = torch.randperm(size)[:subsample_size]
```

took longer, as did the previous implementation.
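The `%time` magic only works inside a notebook. As a hedged way to see why the full-permutation pattern is slow, one can time it at two population sizes with `timeit` (sketched with NumPy instead of PyTorch, so the absolute numbers are only illustrative):

```python
import timeit
import numpy as np

subsample_size = 100
rng = np.random.default_rng(0)

# Cost of the full-permutation approach (the randperm pattern above,
# sketched with NumPy): shuffle ALL indices, then keep only a slice.
t_small = timeit.timeit(
    lambda: rng.permutation(100_000)[:subsample_size], number=5
)
t_large = timeit.timeit(
    lambda: rng.permutation(1_000_000)[:subsample_size], number=5
)
# The cost grows with the population size even though the subsample
# size stays fixed, which is the overhead this PR avoids.
print(f"size=1e5: {t_small:.4f}s, size=1e6: {t_large:.4f}s")
```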
The reason for the higher performance compared to PyTorch is that PyTorch takes a permutation of the full `size` first and then collects a subset, whereas here we only take a permutation of size `subsample_size`. This was observed by @fritzo in this discussion.
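A minimal sketch of that idea (not NumPyro's actual code; the helper name `subsample` is made up for illustration): a partial Fisher–Yates shuffle that performs only `subsample_size` swap steps instead of shuffling all `size` elements.

```python
import numpy as np

def subsample(size, subsample_size, rng):
    """Draw subsample_size distinct indices from range(size) with a
    partial Fisher-Yates shuffle: only subsample_size swaps are done,
    rather than permuting the whole array as randperm(size) would."""
    idx = np.arange(size)
    for i in range(subsample_size):
        # Swap a uniformly chosen element from the unfixed tail into slot i.
        j = rng.integers(i, size)
        idx[i], idx[j] = idx[j], idx[i]
    return idx[:subsample_size]
```

Note that the `np.arange(size)` allocation is still O(size); only the shuffling work now scales with the subsample size.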