-
Notifications
You must be signed in to change notification settings - Fork 185
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: Break up
_kernels
into multiple modules (#428)
Breaks up the `_kernels` module into multiple modules to avoid issues caused by the file growing too large.
- Loading branch information
Showing
13 changed files
with
311 additions
and
202 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
/* | ||
* Copyright (c) 2023 by FlashInfer team. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#include <torch/extension.h> | ||
|
||
#include "flashinfer_ops_decode.h" | ||
|
||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | ||
m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache, | ||
"Single-request decode with KV-Cache operator"); | ||
py::class_<BatchDecodeWithPagedKVCachePyTorchWrapper>(m, | ||
"BatchDecodeWithPagedKVCachePyTorchWrapper") | ||
.def(py::init<unsigned int, bool, unsigned int>()) | ||
.def("begin_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward) | ||
.def("end_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward) | ||
.def("is_cuda_graph_enabled", &BatchDecodeWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled) | ||
.def("update_page_locked_buffer_size", | ||
&BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) | ||
.def("forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::Forward); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
/* | ||
* Copyright (c) 2023 by FlashInfer team. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#pragma once | ||
#include <torch/extension.h> | ||
|
||
#include <flashinfer/attention/handler.cuh> | ||
#include <flashinfer/layout.cuh> | ||
#include <memory> | ||
|
||
torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v, | ||
torch::Tensor tmp, unsigned int pos_encoding_mode, | ||
unsigned int layout, int window_left, | ||
float logits_soft_cap, float sm_scale, float rope_scale, | ||
float rope_theta); | ||
|
||
class BatchDecodeWithPagedKVCachePyTorchWrapper { | ||
public: | ||
void BeginForward(torch::Tensor workspace_buffer, torch::Tensor indptr, | ||
torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads, | ||
unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size, | ||
unsigned int pos_encoding_mode, float logits_soft_cap, | ||
torch::Tensor empty_q_data, torch::Tensor empty_kv_data); | ||
void EndForward(); | ||
void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); | ||
bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } | ||
std::vector<torch::Tensor> Forward(torch::Tensor q, std::optional<torch::Tensor> paged_kv_cache, | ||
std::optional<torch::Tensor> paged_k_cache, | ||
std::optional<torch::Tensor> paged_v_cache, | ||
torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, | ||
torch::Tensor paged_kv_last_page_len, | ||
unsigned int pos_encoding_mode, int window_left, | ||
float logits_soft_cap, float sm_scale, float rope_scale, | ||
float rope_theta, bool return_lse); | ||
BatchDecodeWithPagedKVCachePyTorchWrapper( | ||
std::shared_ptr<flashinfer::BatchDecodeHandler> handler_ptr, flashinfer::QKVLayout kv_layout) | ||
: handler_(handler_ptr), kv_layout_(kv_layout) {} | ||
BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout, bool enable_cuda_graph, | ||
unsigned int fixed_batch_size) | ||
: kv_layout_(flashinfer::QKVLayout(layout)), | ||
handler_(std::make_shared<flashinfer::BatchDecodeHandler>(enable_cuda_graph, | ||
fixed_batch_size)) {} | ||
|
||
protected: | ||
std::shared_ptr<flashinfer::BatchDecodeHandler> handler_; | ||
flashinfer::QKVLayout kv_layout_; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
/* | ||
* Copyright (c) 2023 by FlashInfer team. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#include <torch/extension.h> | ||
|
||
#include "flashinfer_ops_prefill.h" | ||
|
||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | ||
m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache, | ||
"Single-request prefill with KV-Cache operator, return logsumexp"); | ||
m.def( | ||
"single_prefill_with_kv_cache_custom_mask", &single_prefill_with_kv_cache_custom_mask, | ||
"Single-request prefill with KV-Cache operator, user defined custom mask, return logsumexp"); | ||
py::class_<BatchPrefillWithPagedKVCachePyTorchWrapper>( | ||
m, "BatchPrefillWithPagedKVCachePyTorchWrapper") | ||
.def(py::init<unsigned int, bool>()) | ||
.def("begin_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward) | ||
.def("end_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward) | ||
.def("is_cuda_graph_enabled", &BatchPrefillWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled) | ||
.def("update_page_locked_buffer_size", | ||
&BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) | ||
.def("forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::Forward) | ||
.def("forward_custom_mask", &BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCustomMask); | ||
py::class_<BatchPrefillWithRaggedKVCachePyTorchWrapper>( | ||
m, "BatchPrefillWithRaggedKVCachePyTorchWrapper") | ||
.def(py::init<unsigned int, bool>()) | ||
.def("begin_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward) | ||
.def("end_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward) | ||
.def("is_cuda_graph_enabled", | ||
&BatchPrefillWithRaggedKVCachePyTorchWrapper::IsCUDAGraphEnabled) | ||
.def("update_page_locked_buffer_size", | ||
&BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) | ||
.def("forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward) | ||
.def("forward_custom_mask", &BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask); | ||
} |
Oops, something went wrong.