-
Notifications
You must be signed in to change notification settings - Fork 185
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
perf: initial cuda graph support (#256)
As requested in #187 , this PR adds initial support of `CUDAGraph` compatibility of flashinfer batch decode attention kernels. This PR is the first step towards full CUDAGraph support and we will implement CUDAGraph compatible prefill operators in later PRs. # Proposed APIs We add another wrapper `CUDAGraphBatchDecodeWithPagedKVCacheWrapper`, and user need to pre-allocation page data structure buffers to initialize this wrapper class. Once initiated, these buffers are pinned on GPUs in the life cycle of the wrapper class. The behavior of `CUDAGraphBatchDecodeWithPagedKVCacheWrapper` is a little bit different from `BatchDecodeWithPagedKVCacheWrapper`'s: we will only run a fixed set of kernels in CUDAGraph mode, no matter what the input shape is (the original implementation will dispatch to different kernels according to different input shapes). This PR also fix the address of all kernel input pointers to accomodate the constraint of CUDAGraph capturing. # Examples See `test_cuda_graph_batch_decode_with_paged_kv_cache` in unittests. `begin_forward` functions should not be captured as some of the operators are not allowed to be captured. cc @AgrawalAmey @LiuXiaoxuanPKU @comaniac
- Loading branch information
Showing
12 changed files
with
710 additions
and
96 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.