Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Contrib][Sort] Faster Top-K Implementation #13599

Merged
merged 5 commits into from
Jan 4, 2023

Conversation

AndrewZhaoLuo
Copy link
Contributor

@AndrewZhaoLuo AndrewZhaoLuo commented Dec 12, 2022

Summary:

This is a simple rewrite of hand-coded top-k function used for CPU targets.

The old implementation sorted each axis and then took the biggest k elements.

The new implementation does a single pass of each axis, keeping a min heap to store the top-k elements up to that point.

If n is the size of the array, and we want to find top k, the old implementation has runtime in O(nlogn) with additional memory O(n) to store the sorted array. The new implementation is O(n log k), and in practice is probably amortized to O(n / k * log k) in many scenarios and only requires O(k). Note n >> k most of the time.

In practice this new kernel led to a 20x speedup over existing one. On a Xeon Platinum 8370C CPU @ 2.80GHz for input shape [1, 3050] with k = 15, the latency went from 200us --> ~10us. There is probably more room for shaving off a little more time on the scale of a single us's, however I have determined it to not be worth it.

This change however is probably in the range of worth committing.

I've launched benchmarks on my m1 mac, and a Xeon Platinum 8370C CPU @ 2.80GHz with 8 cores.

Data:
All data is collected along axis=1.

M1:

input_shape k onnx_ms tvm_new_ms tvm_old_ms tvm_speedup
(1, 128) 1 0.004760980606 0.00318288 0.00587965 0.8472735384
(1, 128) 4 0.004210948944 0.00353794 0.00526537 0.4882587042
(1, 128) 8 0.00492811203 0.00354373 0.00551627 0.5566281856
(1, 128) 16 0.00457406044 0.00390078 0.00545921 0.3995175324
(1, 128) 32 0.004778146744 0.00437091 0.00566372 0.2957759368
(1, 128) 64 0.004880905151 0.00484171 0.00639961 0.3217664833
(1, 512) 1 0.006465196609 0.00349669 0.01804169 4.159648124
(1, 512) 4 0.004405975342 0.00500128 0.01633621 2.2664058
(1, 512) 8 0.004703998566 0.0041571 0.01799835 3.329544634
(1, 512) 16 0.004922866821 0.00521708 0.01612082 2.090008204
(1, 512) 32 0.005817890167 0.00535875 0.01813752 2.384655003
(1, 512) 64 0.005837917328 0.00758326 0.0165729 1.185458497
(1, 1024) 1 0.007965803146 0.00402747 0.04586543 10.38814938
(1, 1024) 4 0.004651069641 0.00418419 0.03339498 6.981229342
(1, 1024) 8 0.004925966263 0.00467966 0.03454169 6.381239235
(1, 1024) 16 0.005457162857 0.00535332 0.03459499 5.462342995
(1, 1024) 32 0.006209135056 0.00663129 0.03485203 4.255693839
(1, 1024) 64 0.006520032883 0.0090463 0.03381209 2.73767065
(128, 4096) 1 0.5929992199 0.51040627 28.53409492 54.90467163
(128, 4096) 4 0.1157960892 0.53430751 28.58584715 52.50073996
(128, 4096) 8 0.08951592445 0.65343042 27.99126077 41.83740076
(128, 4096) 16 0.1294538975 0.90574999 28.09520573 30.01872044
(128, 4096) 32 0.3273978233 1.39101165 29.52651913 20.22665121
(128, 4096) 64 0.657585144 3.40057753 28.76851003 7.459889468
(128, 128) 1 0.05363202095 0.02861457 0.49791802 16.40085628
(128, 128) 4 0.02515816689 0.05921382 0.48185504 7.137543567
(128, 128) 8 0.05043721199 0.10586789 0.51938795 3.906000771
(128, 128) 16 0.0539188385 0.18564539 0.49861457 1.685844071
(128, 128) 32 0.05814409256 0.31484302 0.48907493 0.5533929575
(128, 128) 64 0.09320878983 0.44958459 0.49063965 0.09131776514
(128, 512) 1 0.2344889641 0.06842087 2.6640458 37.93615793
(128, 512) 4 0.03355693817 0.11402793 2.64807838 22.22306807
(128, 512) 8 0.02771997452 0.19101078 2.65586672 12.90427661
(128, 512) 16 0.04784107208 0.33276708 2.67728463 7.0455213
(128, 512) 32 0.1548981667 0.6094633 2.67251341 3.38502763
(128, 512) 64 0.3678889275 1.03113543 2.65719582 1.57696103
(128, 1024) 1 0.4660680294 0.12275126 5.66230752 45.12830467
(128, 1024) 4 0.02821993828 0.17523665 5.76842294 31.91790239
(128, 1024) 8 0.03698420525 0.27126284 5.72688211 20.11193007
(128, 1024) 16 0.08447217941 0.45277246 5.7195709 11.63232949
(128, 1024) 32 0.141343832 0.78713792 5.69631002 6.236736886
(128, 1024) 64 0.3908679485 1.38000749 5.70882461 3.136806975
(128, 4096) 1 0.5177731514 0.45661204 29.01612208 62.54655493
(128, 4096) 4 0.07772278786 0.52315133 28.11686041 52.7451762
(128, 4096) 8 0.08546304703 0.64544875 28.07437255 42.49589731
(128, 4096) 16 0.1305837631 0.89367784 28.22949792 30.58800258
(128, 4096) 32 0.2483069897 1.3729229 28.26283543 19.58588682
(128, 4096) 64 0.5679419041 2.22060246 28.05246176 11.63281576
(1, 128, 128) 1 0.06175613403 0.02783994 0.48192299 16.31048953
(1, 128, 128) 4 0.02490067482 0.05053001 0.4883058 8.663679069
(1, 128, 128) 8 0.04355311394 0.13984835 0.50664956 2.622849751
(1, 128, 128) 16 0.1872887611 0.18108045 0.50739745 1.802055385
(1, 128, 128) 32 0.3335990906 0.31482086 0.49584384 0.5750031304
(1, 128, 128) 64 0.306704998 0.45216494 0.50409458 0.1148466752
(1, 512, 128) 1 0.2505269051 0.09281251 2.68397045 27.91819702
(1, 512, 128) 4 0.07675075531 0.13602672 2.69848247 18.83788531
(1, 512, 128) 8 0.1162371635 0.20078161 2.69511252 12.42310444
(1, 512, 128) 16 0.2757160664 0.36553754 2.73840629 6.491450235
(1, 512, 128) 32 0.5648989677 0.60555339 2.70246499 3.462802182
(1, 512, 128) 64 1.145206928 1.05855501 2.70480376 1.555184884
(1, 1024, 128) 1 0.4711620808 0.17656796 5.79110494 31.79816417
(1, 1024, 128) 4 0.1256010532 0.217788 5.83131579 25.77519326
(1, 1024, 128) 8 0.1894469261 0.29612498 5.80318294 18.5970733
(1, 1024, 128) 16 0.354445219 0.46999371 5.7760871 11.28971149
(1, 1024, 128) 32 0.7124738693 0.82430171 5.78476507 6.017776379
(1, 1024, 128) 64 1.514050007 1.40707462 5.73287128 3.074319299
(1, 4096, 128) 1 2.008599997 0.63573374 28.47925624 43.79745914
(1, 4096, 128) 4 0.4299049377 0.70525291 28.16252006 38.93251167
(1, 4096, 128) 8 0.4861471653 0.81754459 28.17799444 33.46661477
(1, 4096, 128) 16 0.6892678738 1.03458623 28.50899588 26.55593981
(1, 4096, 128) 32 1.158033848 1.52110841 28.35097251 17.63836418
(1, 4096, 128) 64 2.10275507 2.37215047 28.24873166 10.90849064

Xeon:

input_shape k onnx_ms tvm_new_ms tvm_old_ms tvm_speedup
(1, 128) 1 0.004903078079 0.00410365 0.00624143 0.5209459871
(1, 128) 4 0.005003929138 0.0043396 0.00646044 0.4887178542
(1, 128) 8 0.005117177963 0.00446607 0.00637956 0.4284505169
(1, 128) 16 0.005379199982 0.00552826 0.00628157 0.1362652987
(1, 128) 32 0.005692481995 0.0067702 0.00636218 -0.06026705267
(1, 128) 64 0.005892515182 0.00750643 0.00641667 -0.1451768684
(1, 512) 1 0.005101680756 0.0045848 0.01703391 2.715300558
(1, 512) 4 0.005270719528 0.00504706 0.01705142 2.378485693
(1, 512) 8 0.005369901657 0.00546179 0.01737532 2.181250103
(1, 512) 16 0.005724430084 0.00659026 0.01734728 1.632260336
(1, 512) 32 0.00634932518 0.00810608 0.01715245 1.115998115
(1, 512) 64 0.007574081421 0.0123074 0.01790654 0.4549409298
(1, 1024) 1 0.005491256714 0.00518735 0.04105347 6.914150771
(1, 1024) 4 0.005714654922 0.00605234 0.04116495 5.801493307
(1, 1024) 8 0.00586771965 0.0061588 0.04073779 5.614566149
(1, 1024) 16 0.006246089935 0.00775657 0.04099463 4.285149235
(1, 1024) 32 0.006800889969 0.00981186 0.04086617 3.164976875
(1, 1024) 64 0.008088588715 0.01391958 0.04071081 1.924715401
(128, 4096) 1 0.1717588902 0.64934249 34.07038682 51.469055
(128, 4096) 4 0.1751728058 0.77309782 34.17276327 43.20237955
(128, 4096) 8 0.2098040581 0.96305912 34.05676189 34.36310615
(128, 4096) 16 0.2812550068 1.33817001 34.37907734 24.6911133
(128, 4096) 32 0.4438226223 2.10236207 34.31240895 15.3208847
(128, 4096) 64 0.7761180401 3.30002793 34.36870606 9.414671266
(128, 128) 1 0.01773786545 0.04629403 0.63880666 12.79889934
(128, 128) 4 0.03143501282 0.10687754 0.63899949 4.978800504
(128, 128) 8 0.07985520363 0.17921997 0.6422134 2.583380803
(128, 128) 16 0.08305168152 0.32098773 0.63938078 0.9919165758
(128, 128) 32 0.1176977158 0.52143936 0.64236284 0.2319032457
(128, 128) 64 0.1545715332 0.75031197 0.64036611 -0.1465335279
(128, 512) 1 0.04941511154 0.10803366 3.23408057 28.93586045
(128, 512) 4 0.03600502014 0.19161847 3.24270817 15.92273281
(128, 512) 8 0.05524492264 0.31474983 3.2444747 9.308106282
(128, 512) 16 0.1011757851 0.54488891 3.26249455 4.987448983
(128, 512) 32 0.2080805302 0.95523322 3.24779573 2.400002912
(128, 512) 64 0.3946566582 1.59471235 3.26046632 1.044548235
(128, 1024) 1 0.08969402313 0.18861577 7.19580053 37.15057739
(128, 1024) 4 0.05647706985 0.28350667 7.21739183 24.45757329
(128, 1024) 8 0.08061361313 0.4269163 7.28504533 16.06434102
(128, 1024) 16 0.135792017 0.7080551 7.1932555 9.159174759
(128, 1024) 32 0.2607634068 1.21888382 7.17930309 4.890063493
(128, 1024) 64 0.4947025776 2.06306489 7.17537368 2.478016477
(128, 4096) 1 0.1604876518 0.64992401 34.3822333 51.90192818
(128, 4096) 4 0.1740963459 0.7714043 34.10499986 43.21157603
(128, 4096) 8 0.2261743546 0.95917291 34.14983401 34.60341796
(128, 4096) 16 0.3008635044 1.33355577 34.3734485 24.77578627
(128, 4096) 32 0.4579799175 2.03814355 34.40618558 15.8811395
(128, 4096) 64 0.780279398 3.28102778 34.1401224 9.405313423
(1, 128, 128) 1 0.01952576637 0.0462761 0.64042845 12.83929177
(1, 128, 128) 4 0.0309150219 0.10419626 0.63757635 5.118994578
(1, 128, 128) 8 0.08383250237 0.186736 0.63514115 2.401278543
(1, 128, 128) 16 0.1831464767 0.33090011 0.63958008 0.9328494028
(1, 128, 128) 32 0.3689568043 0.52617823 0.64795514 0.231436618
(1, 128, 128) 64 0.4878430367 0.78079381 0.66757144 -0.1450093079
(1, 512, 128) 1 0.05654454231 0.11988411 3.2500662 26.11006655
(1, 512, 128) 4 0.08462810516 0.20252434 3.22769785 14.93733301
(1, 512, 128) 8 0.1554389 0.31994938 3.21968138 9.063096169
(1, 512, 128) 16 0.3183102608 0.56696191 3.24684763 4.726747375
(1, 512, 128) 32 0.6605670452 0.97589354 3.25511664 2.335524324
(1, 512, 128) 64 1.274202108 1.62597913 3.25872819 1.004163602
(1, 1024, 128) 1 0.1024491787 0.21505164 7.20831624 32.51900148
(1, 1024, 128) 4 0.146468401 0.31597621 7.20327928 21.79690386
(1, 1024, 128) 8 0.230694294 0.46149483 7.20596765 14.61440602
(1, 1024, 128) 16 0.4166545868 0.74174002 7.19909731 8.705688133
(1, 1024, 128) 32 0.8144786358 1.24213877 7.20898491 4.803687224
(1, 1024, 128) 64 1.615594864 2.11381521 7.20788828 2.409895172
(1, 4096, 128) 1 1.501801014 1.53394805 35.07190699 21.86381667
(1, 4096, 128) 4 1.525020361 1.63445534 34.68660191 20.22211667
(1, 4096, 128) 8 1.592685223 1.81754422 34.75576938 18.12237898
(1, 4096, 128) 16 1.766026497 2.14398094 34.77502366 15.21983806
(1, 4096, 128) 32 2.223840714 2.81438604 34.83320587 11.37684005
(1, 4096, 128) 64 3.18167758 4.0325292 34.87056221 7.647317969

As can be seen, except in one pathological case (k ~ axis_size), we see significant speedups across almost all conditions. For M1, this case also has speedups surprisingly.

Other Changes:

  • Consolidate fp16 comparison functions by overloading struct with comparison operators
  • Add stable sorting option for comparison functions

@tvm-bot
Copy link
Collaborator

tvm-bot commented Dec 12, 2022

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

  • No users to tag found in teams: contrib, sort See #10317 for details

Generated by tvm-bot

@AndrewZhaoLuo AndrewZhaoLuo changed the title [Contrib] Faster Top-K Impl. [Contrib] Faster Top-K Implementation Dec 12, 2022
@AndrewZhaoLuo AndrewZhaoLuo marked this pull request as ready for review December 13, 2022 22:03
@AndrewZhaoLuo AndrewZhaoLuo changed the title [Contrib] Faster Top-K Implementation [Contrib][Sort] Faster Top-K Implementation Dec 13, 2022
Copy link
Contributor

@tkonolige tkonolige left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Numbers look excellent! We could probably simplify the implementation to use std::partial_sort (https://en.cppreference.com/w/cpp/algorithm/partial_sort), but that can wait for a future PR.

@tkonolige tkonolige merged commit 231882a into apache:main Jan 4, 2023
fzi-peccia pushed a commit to fzi-peccia/tvm that referenced this pull request Mar 27, 2023
This is a simple rewrite of hand-coded top-k function used for CPU targets.

The old implementation sorted each axis and then took the biggest k elements.

The new implementation does a single pass of each axis, keeping a min heap to store the top-k elements up to that point.

If n is the size of the array, and we want to find top k, the old implementation has runtime in O(nlogn) with additional memory O(n) to store the sorted array. The new implementation is O(n log k), and in practice is probably amortized to O(n / k * log k) in many scenarios and only requires O(k). Note n >> k most of the time.

In practice this new kernel led to a 20x speedup over existing one. On a Xeon Platinum 8370C CPU @ 2.80GHz for input shape [1, 3050] with k = 15, the latency went from 200us --> ~10us. There is probably more room for shaving off a little more time on the scale of a single us's, however I have determined it to not be worth it.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants