perf: fix the performance issue of append_paged_kv_cache #588

Merged · 6 commits into main · Nov 6, 2024

Conversation

@yzh119 (Collaborator) commented on Nov 6, 2024

The performance of `append_paged_kv_cache` is terrible for small batch sizes, a known issue that we hadn't fixed for a long time; this PR fixes it. This PR also adds support for non-contiguous append keys/values (which could be sliced from a fused qkv matrix).
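
Below is a minimal sketch of the non-contiguous case; the shapes and variable names are hypothetical and only meant to show that k/v slices of a fused qkv projection are strided views rather than contiguous tensors:

```python
import torch

# Hypothetical shapes for illustration only.
nnz, num_qo_heads, num_kv_heads, head_dim = 5000, 32, 8, 128
qkv = torch.randn(nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim,
                  dtype=torch.float16, device="cuda")
q, k, v = qkv.split(
    [num_qo_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim],
    dim=-1,
)
# k and v are strided views into qkv; with this PR they can be consumed by the
# append kernel without an extra .contiguous() copy.
k = k.view(nnz, num_kv_heads, head_dim)
v = v.view(nnz, num_kv_heads, head_dim)
assert not k.is_contiguous() and not v.is_contiguous()
```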

We first call a Triton kernel to convert `append_indptr` to `batch_indices` and `positions` (similar to the CSR-to-COO conversion for sparse matrices). After the conversion, we can use element parallelism instead of batch parallelism.
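
For reference, here is a pure-PyTorch sketch of that expansion (the function name and the torch-level implementation are illustrative; the PR performs this step in a Triton kernel):

```python
import torch

def batch_indices_positions(append_indptr: torch.Tensor, seq_lens: torch.Tensor):
    """CSR-to-COO style expansion: for every appended token, record which request
    it belongs to (batch index) and its absolute position in that request's KV cache.

    append_indptr: [batch_size + 1] prefix sum of per-request append lengths
    seq_lens:      [batch_size] total KV length of each request after the append
    """
    batch_size = append_indptr.numel() - 1
    nnz = int(append_indptr[-1].item())
    append_lens = append_indptr[1:] - append_indptr[:-1]
    # batch index of every appended token (the COO "row indices")
    batch_indices = torch.repeat_interleave(
        torch.arange(batch_size, device=append_indptr.device), append_lens
    )
    # offset of each token within its own append segment: 0, 1, ..., len_i - 1
    offsets = torch.arange(nnz, device=append_indptr.device) - append_indptr[batch_indices]
    # absolute position in the KV cache = previous length + offset
    positions = (seq_lens - append_lens)[batch_indices] + offsets
    return batch_indices, positions
```

For example, `append_indptr = [0, 3, 5]` with `seq_lens = [10, 4]` yields `batch_indices = [0, 0, 0, 1, 1]` and `positions = [7, 8, 9, 2, 3]`.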

It's also worth trying Triton for the second kernel, `AppendPagedKVCacheKernel`; I think the performance should be fine. I'll leave that for future work.
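
To illustrate what element parallelism buys us, here is a non-Triton reference sketch of the scatter that `batch_indices`/`positions` enable, assuming an HND-layout paged cache; the names and exact cache layout are assumptions for illustration, not the actual kernel:

```python
import torch

def append_paged_kv_reference(
    paged_k: torch.Tensor,      # [max_num_pages, num_kv_heads, page_size, head_dim]
    paged_v: torch.Tensor,      # same shape as paged_k
    append_k: torch.Tensor,     # [nnz, num_kv_heads, head_dim], may be a strided view
    append_v: torch.Tensor,     # [nnz, num_kv_heads, head_dim]
    kv_indices: torch.Tensor,   # page ids of all requests, CSR layout
    kv_indptr: torch.Tensor,    # [batch_size + 1] prefix sum of page counts per request
    batch_indices: torch.Tensor,
    positions: torch.Tensor,
):
    """Reference scatter: every appended token independently computes its
    destination page and in-page slot, so the work is parallel over the nnz
    appended tokens rather than over requests."""
    page_size = paged_k.shape[2]
    page_in_request = positions // page_size          # which logical page of the request
    slot_in_page = positions % page_size              # offset inside that page
    page_ids = kv_indices[kv_indptr[batch_indices] + page_in_request]
    paged_k[page_ids, :, slot_in_page, :] = append_k
    paged_v[page_ids, :, slot_in_page, :] = append_v
```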

Some TODO items:

  1. Add `torch.compile` support.

After this PR (reference numbers can be found in #583):

```
model: l1b      seqlens: [1, 1, 1, 1, 1, 1, 1, 1]                 single_layer: 0.006ms all_layers:   0.094ms throughput:    5.563GB/s
model: l1b      seqlens: [4993, 1, 1, 1, 1, 1, 1, 1]              single_layer: 0.014ms all_layers:   0.216ms throughput: 1514.280GB/s
model: l1b      seqlens: [5000]                                   single_layer: 0.014ms all_layers:   0.216ms throughput: 1517.017GB/s
model: l1b      seqlens: [625, 625, 625, 625, 625, 625, 625, 625] single_layer: 0.014ms all_layers:   0.217ms throughput: 1510.863GB/s
---
model: l3b      seqlens: [1, 1, 1, 1, 1, 1, 1, 1]                 single_layer: 0.006ms all_layers:   0.165ms throughput:   11.123GB/s
model: l3b      seqlens: [4993, 1, 1, 1, 1, 1, 1, 1]              single_layer: 0.021ms all_layers:   0.580ms throughput: 1975.732GB/s
model: l3b      seqlens: [5000]                                   single_layer: 0.021ms all_layers:   0.586ms throughput: 1958.078GB/s
model: l3b      seqlens: [625, 625, 625, 625, 625, 625, 625, 625] single_layer: 0.021ms all_layers:   0.581ms throughput: 1973.174GB/s
---
model: l8b      seqlens: [1, 1, 1, 1, 1, 1, 1, 1]                 single_layer: 0.006ms all_layers:   0.185ms throughput:   11.321GB/s
model: l8b      seqlens: [4993, 1, 1, 1, 1, 1, 1, 1]              single_layer: 0.021ms all_layers:   0.661ms throughput: 1982.815GB/s
model: l8b      seqlens: [5000]                                   single_layer: 0.021ms all_layers:   0.662ms throughput: 1980.227GB/s
model: l8b      seqlens: [625, 625, 625, 625, 625, 625, 625, 625] single_layer: 0.021ms all_layers:   0.667ms throughput: 1964.861GB/s
---
model: l70b-tp8 seqlens: [1, 1, 1, 1, 1, 1, 1, 1]                 single_layer: 0.006ms all_layers:   0.457ms throughput:    1.434GB/s
model: l70b-tp8 seqlens: [4993, 1, 1, 1, 1, 1, 1, 1]              single_layer: 0.009ms all_layers:   0.710ms throughput:  576.866GB/s
model: l70b-tp8 seqlens: [5000]                                   single_layer: 0.009ms all_layers:   0.685ms throughput:  598.366GB/s
model: l70b-tp8 seqlens: [625, 625, 625, 625, 625, 625, 625, 625] single_layer: 0.009ms all_layers:   0.690ms throughput:  593.453GB/s
```

cc @abcdabcd987

@yzh119 yzh119 merged commit e15f7c9 into main Nov 6, 2024
@abcdabcd987 (Member) commented:
Awesome! Huge improvements! Thanks! 🙏

@yzh119 yzh119 deleted the accelerate-append-kv-cache branch November 10, 2024 08:47
yzh119 added a commit that referenced this pull request Dec 17, 2024
🤖 I have created a release *beep* *boop*
---


## [0.2.0](v0.1.6...v0.2.0) (2024-12-17)

[Release Blog](https://flashinfer.ai/2024/12/16/flashinfer-v02-release.html).

### Features

* add `rotary_dim` argument to rope APIs for partial apply rope
([#599](#599))
([eb9bc71](eb9bc71))
* add a `use_softmax` field in variant class
([#533](#533))
([d81af97](d81af97))
* add an option `non_blocking` to plan function
([#622](#622))
([560af6f](560af6f))
* add gemma_rmsnorm and gemma_fused_add_rmsnorm
([#477](#477))
([1a6b17e](1a6b17e))
* add group size 3 to GQA decode dispatch
([#558](#558))
([6227562](6227562))
* add JIT compilation support for FA3 templates
([#672](#672))
([d4e8d79](d4e8d79))
* allow the cascade kernels to be executed using varying sequence
lengths ([#627](#627))
([92ac440](92ac440))
* CUDAGraph compatibility of multi-level cascade inference APIs
([#586](#586))
([2332e8a](2332e8a))
* fix the maximal grid dimension in prefill planning with CUDA graphs
([#639](#639))
([86ca89a](86ca89a))
* improve the precision of the FusedAddRMSNormKernel function
([#587](#587))
([c7dc921](c7dc921))
* JIT compilation
([#507](#507))
([3613a5b](3613a5b))
* modify group-gemm stage number
([#497](#497))
([52dab1d](52dab1d))
* non-contiguous query with paged kv cache
([#553](#553))
([89f2c4a](89f2c4a))
* pass a dynamic token count to the cascade kernels
([#635](#635))
([5fe9f7d](5fe9f7d))
* simplify prefill JIT compilation
([#605](#605))
([fe4f898](fe4f898))
* specify gemm backend
([#648](#648))
([0cc1a51](0cc1a51))
* support cached cos/sin in rope APIs
([#585](#585))
([83e541d](83e541d))
* support huggingface transformer style rope interface
([#568](#568))
([4f40420](4f40420))
* support sm90 cutlass group gemm
([#509](#509))
([794bdda](794bdda))
* torch custom_op fix for rope
([#569](#569))
([3e104bc](3e104bc))
* torch custom_op support: norm
([#552](#552))
([f6e0010](f6e0010))
* torch.compile and custom_op support
([#554](#554))
([9bf916f](9bf916f))
* warmup for jit kernel tests
([#629](#629))
([8f5f349](8f5f349))


### Bug Fixes

* AOT compiler flags on non-sm90
([#522](#522))
([0aa4726](0aa4726))
* batch decode kernel redundant store output to gmem
([#505](#505))
([90e42a7](90e42a7))
* compatible with torch 2.2
([#478](#478))
([ac41d1b](ac41d1b))
* #452
([b53a46f](b53a46f))
* remove redundant load
([#495](#495))
([2de16b0](2de16b0))
* update bmm fp8 test
([#487](#487))
([45eac04](45eac04))


### Performance Improvements

* accelerate JIT compilation speed
([#618](#618))
([eaf73fd](eaf73fd))
* Dense and sparse customizable flashattention-3 template
([#667](#667))
([51236c9](51236c9))
* fix prefill kernel performance degradation (step 1)
([#602](#602))
([595cf60](595cf60))
* fix the performance issue of `append_paged_kv_cache`
([#588](#588))
([e15f7c9](e15f7c9))
* improve parallelism in RoPE with pos_ids
([#609](#609))
([ff05155](ff05155))
* improve plan performance by using non-blocking memcpy
([#547](#547))
([41ebe6d](41ebe6d))
* reduce the read and write of shared memory in the
FusedAddRMSNormKernel
([#592](#592))
([2043ca2](2043ca2))
* reduce total_num_tiles_q by one
([#644](#644))
([553ace5](553ace5))
* remove unnecessary contiguous operation in block sparse attention
([#561](#561))
([7a7ad46](7a7ad46))
* speedup jit compilation of prefill attention kernels
([#632](#632))
([a059586](a059586))
* use cuda-core implementation for io-bound block-sparse attention
([#560](#560))
([3fbf028](3fbf028))

---
This PR was generated with [Release
Please](https://github.com/googleapis/release-please). See
[documentation](https://github.com/googleapis/release-please#release-please).

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zihao Ye <[email protected]>