
optimized topp/topk calculation #195

Merged — 1 commit merged into habana_main on Sep 17, 2024

Conversation


@ssarkar2 commented Aug 19, 2024

One-line description

Use topk instead of sort for the top-p/top-k calculation under certain conditions (scalar values of p and k).

Details

Instead of using k for topk, we use _padded_k, which is strictly larger than k and monotonically non-decreasing.

We need _padded_k > k for cases where the smallest of the top-k values is duplicated beyond position k (for example, for [9,8,8,8,7,7,7] with k=3 we get [9,8,8,8], i.e. 4 values instead of 3).

To prevent excessive recompilations, any time _padded_k needs to be expanded we increment it by a fixed constant _increment (usually >1); this bucketed approach avoids creating many distinct shapes.

Basic outline

  1. Perform topk with _padded_k.
  2. Find the "k-th" value in each row (the smallest number that will be in the top-k) and count how many times it is duplicated; this is the variable num_duplicates_of_smallest_of_topk.
  3. Take the maximum of the number of duplicates across rows; this is max_num_duplicates_of_smallest_of_topk.
  4. Check whether _padded_k is big enough to contain max_num_duplicates_of_smallest_of_topk; if not, expand _padded_k and redo the topk with the expanded _padded_k.
  5. Mask out the extra values included by _padded_k.
  6. Move on to top-p (steps 1-5 are sketched in the snippet below).
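
For illustration, here is a minimal self-contained sketch of steps 1-5 (the function name, the -inf fill value, and the growth policy are placeholders for this write-up, not the exact identifiers or logic in the PR):

import torch

def topk_with_padding(probs, k, padded_k, increment=32):
    # probs: [batch, vocab]; k: scalar top-k; padded_k: current bucket size
    vocab = probs.shape[-1]
    padded_k = min(padded_k, vocab)
    while True:
        # 1. topk with the padded size (values are returned in descending order)
        vals, idx = torch.topk(probs, padded_k, dim=-1)
        # 2. the k-th largest value per row (smallest member of the top-k)
        kth_val = vals[:, k - 1 : k]                      # shape [batch, 1]
        # 3. max over rows of how many entries tie with or exceed that value
        max_needed = int((probs >= kth_val).sum(dim=-1).max())
        # 4. if the bucket cannot hold all duplicates, grow it by a fixed
        #    increment and redo the topk; bucketing limits recompilations
        if max_needed <= padded_k or padded_k == vocab:
            break
        padded_k = min(padded_k + increment, vocab)
    # 5. mask out the padded extras that fall below the k-th value
    vals = vals.masked_fill(vals < kth_val, float("-inf"))
    return vals, idx, padded_k   # 6. top-p is then applied to vals

Because padded_k only ever grows in steps of increment, it takes values from a small set of buckets, so at most a handful of shapes are ever compiled.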

Perf benefit

The feature gives a 49% increase in throughput in the case with warmup, and a 30% increase in throughput in the case without warmup.

Extra Notes

  1. Works only for "scalar" case, though it might be possible to extend the basic idea (topk instead of sort) for vector case as well. (Outline of this is: find max k in topk vector, then perform topk using that, etc. needs some bucketing possibly to prevent dyn shapes etc)
  2. Need an additional check in _init_sampling_tensors to determine if its scalar case. This has a minor perf hit. ideally if someone could tell us that its a scalar from the top itself...
  3. Some tradeoffs can be made, where we use a sufficiently large padded_k (which is still smaller than vocab size) from the beginning, and hope that every case lands within that bucket. Cases that wont land are expected to be very, very rare. For example if padded_k = max(2 * k, 100) is used, and k = say 50, then we need the smallest of the topk value to repeat 50 times with same probability, which is exceedingly unlikely. If we trade off this mathematical improbability, then we can do with only 1 topk op, which might be faster
  4. There is a fliplr in the code, which could be removed, if we can compute reverse cumsum. however the formula for reverse cumsum as expressed here , x + torch.sum(x, dim=1, keepdims=True) - torch.cumsum(x, dim=1) is numerically unstable, because of the addition/subtraction. It works well enough on ints and large numbers, but not on small probability values.
  5. The value of k affects the gains we might get from this. For example in the expt shown above, with k=20, thruput increases around 30%. But if k = 2000, instead of 20, throughput increases the gain is 14%. Thus the gain % might decrease with increasing k, as asymptotically topk would probably converges to sort's performance for large k. However practically k is pretty small.
  6. For larger models, the gains may be less, as they are more device bound probably
  7. Cumsum may be taking long. Maybe try below. Initial try
import torch
# cumsum of each row computed as a masked sum with a lower-triangular mask
y = torch.tensor([[1, 2, 3], [4, 5, 6]])
mask1 = torch.tensor([[[1,0,0],[1,1,0],[1,1,1]], [[1,0,0],[1,1,0],[1,1,1]]])
torch.sum(y.unsqueeze(1) * mask1, 2)

or

import torch.nn.functional as F
# cumsum via a 1D convolution of a zero-padded sequence with a ones kernel
F.conv1d(torch.tensor([[[0,0,0,0,1,2,3,4,5]], [[0,0,0,0,6,7,8,9,10.0]]]), torch.ones([1,1,5], dtype=torch.float32))
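
To make note 4 concrete, here is a small stand-alone comparison of the flip-based reverse cumsum used today and the closed-form alternative (x here is just a random stand-in, not the sampler's tensors):

import torch

x = torch.rand(2, 8)
# reverse (right-to-left) cumsum via two flips -- what the code does today
rev_flip = torch.fliplr(torch.cumsum(torch.fliplr(x), dim=1))
# closed-form reverse cumsum; avoids the flips, but the add/subtract can
# cancel catastrophically when x holds small probability values
rev_formula = x + torch.sum(x, dim=1, keepdim=True) - torch.cumsum(x, dim=1)
print((rev_flip - rev_formula).abs().max())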

@ssarkar2 force-pushed the sarkar/apply_topp_topk_scalar_opt branch 7 times, most recently from 388f62b to ef9a5ba on August 20, 2024 16:38
@ssarkar2 force-pushed the sarkar/apply_topp_topk_scalar_opt branch from ef9a5ba to 2ab316d on August 27, 2024 17:01
@kzawora-intel added the habana (Issues or PRs submitted by Habana Labs) label on Aug 29, 2024
libinta pushed a commit that referenced this pull request Sep 10, 2024
@michalkuligowski merged commit 4c1ca3a into habana_main on Sep 17, 2024
13 checks passed
michalkuligowski pushed a commit that referenced this pull request Sep 26, 2024
Reverted PRs:
- #250 
- #195

---------

Signed-off-by: Russell Bryant <[email protected]>
Signed-off-by: Alex-Brooks <[email protected]>
Co-authored-by: Jani Monoses <[email protected]>
Co-authored-by: Daniele <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Co-authored-by: mgoin <[email protected]>
Co-authored-by: Divakar Verma <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
Co-authored-by: Russell Bryant <[email protected]>
Co-authored-by: jiqing-feng <[email protected]>
Co-authored-by: Alexander Matveev <[email protected]>
Co-authored-by: Hongxia Yang <[email protected]>
Co-authored-by: sroy745 <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Co-authored-by: Brendan Wong <[email protected]>
Co-authored-by: Simon Mo <[email protected]>
Co-authored-by: Cody Yu <[email protected]>
Co-authored-by: Peter Salas <[email protected]>
Co-authored-by: Alex Brooks <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: DarkLight1337 <[email protected]>
Co-authored-by: Hanzhi Zhou <[email protected]>
Co-authored-by: Kunshang Ji <[email protected]>