One line description
Use topk instead of sort for topp/topk calculation under certain conditions (scalar value of p and k).
Details
Instead of using `k` for topk, we use `_padded_k`, which is strictly larger than `k` and monotonically non-decreasing. We need `_padded_k > k` for cases where the smallest of the top-k values has duplicates beyond position k (for example, for [9, 8, 8, 8, 7, 7, 7] with k=3, we need [9, 8, 8, 8], which is 4 values instead of 3). To prevent excessive recompilations, any time we require an expansion of `_padded_k` we increment it by a fixed constant `_increment` (usually >1), giving a bucketed approach that avoids producing many distinct shapes.
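To make the duplicate issue concrete, here is a small illustration (not from the PR) of how `torch.topk` truncates ties at the k-th value, which is why a larger `_padded_k` is needed to capture all of them:

```python
import torch

logits = torch.tensor([9., 8., 8., 8., 7., 7., 7.])

# Plain topk with k=3 keeps exactly 3 values, cutting off one of the tied 8s.
vals, idx = torch.topk(logits, k=3)
print(vals)  # tensor([9., 8., 8.])

# The smallest kept value (8) occurs 3 times in the input, so 4 values
# (everything >= 8) are needed to reproduce what sort-based masking keeps.
print((logits >= vals[-1]).sum())  # tensor(4)
```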
Basic outline:
1. Perform topk with `_padded_k` values per row.
2. Find `num_duplicates_of_smallest_of_topk`, the number of duplicates of the smallest of the top-k values in each row.
3. Find `max_num_duplicates_of_smallest_of_topk`, the maximum of these counts across the batch.
4. Check if `_padded_k` is big enough to contain `max_num_duplicates_of_smallest_of_topk`. If not, expand `_padded_k` and redo the topk with the expanded `_padded_k`.
5. Perform the top-p/top-k computation on the resulting `_padded_k` values instead of the fully sorted logits.
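The following is a minimal, self-contained sketch of that outline, not the actual PR code; the helper name `topk_with_padding` and the variables `padded_k`/`increment` are stand-ins for `_padded_k`/`_increment`, and the batch bookkeeping of the real sampler is simplified:

```python
import torch

def topk_with_padding(logits: torch.Tensor, k: int, increment: int = 32):
    """Grow padded_k in fixed-size buckets until the top-padded_k window of
    every row contains all duplicates of that row's k-th largest value.
    Hypothetical sketch; the PR's actual bookkeeping may differ."""
    vocab = logits.shape[-1]
    padded_k = min(k + increment, vocab)
    while True:
        vals, idx = torch.topk(logits, k=padded_k, dim=-1)
        # Smallest of the "real" top-k values in each row.
        kth_val = vals[..., k - 1:k]
        # Entries tied with or above the k-th value: the top (k-1) values
        # plus every duplicate of the smallest of the top-k.
        num_needed = (logits >= kth_val).sum(dim=-1)
        max_needed = int(num_needed.max())
        if max_needed <= padded_k or padded_k == vocab:
            return vals, idx, padded_k
        # Bucketed expansion: round the shortfall up to a multiple of
        # `increment` so only a few distinct shapes are ever compiled.
        shortfall = max_needed - padded_k
        padded_k = min(padded_k
                       + ((shortfall + increment - 1) // increment) * increment,
                       vocab)
```

After this, the top-p/top-k filtering only has to operate on `vals`, which is much smaller than the full vocabulary.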
Perf benefit
The feature gives a 49% increase in throughput in the case with warmup, and a 30% increase in throughput in the case without warmup.
Extra Notes
- We use `_init_sampling_tensors` to determine if it is the scalar case. This has a minor perf hit; ideally someone could tell us that it is a scalar from the top itself...
- There is a `fliplr` in the code, which could be removed if we could compute a reverse cumsum. However, the formula for reverse cumsum as expressed here, `x + torch.sum(x, dim=1, keepdims=True) - torch.cumsum(x, dim=1)`, is numerically unstable because of the addition/subtraction. It works well enough on ints and large numbers, but not on small probability values (see the sketch after this list).
- `k` affects the gains we might get from this. For example, in the experiment shown above, with k=20 throughput increases by around 30%, but with k=2000 instead of 20 the gain is 14%. Thus the gain might decrease with increasing k, as asymptotically topk would probably converge to sort's performance for large k. However, in practice k is pretty small.
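To illustrate the instability mentioned above, here is a small standalone comparison (not from the PR) of one way to write a flip-based reverse cumsum against the sum-minus-cumsum identity, on small probability values:

```python
import torch

torch.manual_seed(0)

# A row of small probabilities, like a softmax over a large vocab.
x = torch.softmax(torch.randn(1, 32000), dim=1)

# Reverse cumsum via flipping: accumulate from the right, so the small
# tail values are summed first.
rev_flip = torch.fliplr(torch.cumsum(torch.fliplr(x), dim=1))

# Reverse cumsum via the sum-minus-cumsum identity quoted in the note.
rev_formula = x + torch.sum(x, dim=1, keepdim=True) - torch.cumsum(x, dim=1)

# Mathematically identical, but in the tail the formula cancels two numbers
# that are both close to 1.0, so its relative error on small values grows.
rel_err = ((rev_formula - rev_flip).abs() / rev_flip).max()
print(rel_err)
```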