diff --git a/python/csrc/batch_decode.cu b/python/csrc/batch_decode.cu index 130f4abb..94365376 100644 --- a/python/csrc/batch_decode.cu +++ b/python/csrc/batch_decode.cu @@ -15,7 +15,7 @@ */ #include -#include "flashinfer_ops.h" +#include "flashinfer_ops_decode.h" #include "pytorch_extension_utils.h" using namespace flashinfer; diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index d54bddff..51494150 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -15,7 +15,7 @@ */ #include -#include "flashinfer_ops.h" +#include "flashinfer_ops_prefill.h" #include "pytorch_extension_utils.h" using namespace flashinfer; diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu index 49c0f518..51dbb2b8 100644 --- a/python/csrc/flashinfer_ops.cu +++ b/python/csrc/flashinfer_ops.cu @@ -18,13 +18,6 @@ #include "flashinfer_ops.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache, - "Single-request decode with KV-Cache operator"); - m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache, - "Single-request prefill with KV-Cache operator, return logsumexp"); - m.def( - "single_prefill_with_kv_cache_custom_mask", &single_prefill_with_kv_cache_custom_mask, - "Single-request prefill with KV-Cache operator, user defined custom mask, return logsumexp"); m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator"); m.def("merge_state", &merge_state, "Merge two self-attention states"); m.def("merge_state_in_place", &merge_state_in_place, @@ -50,36 +43,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE"); m.def("packbits", &packbits, "GPU packbits operator"); m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator"); - py::class_(m, - "BatchDecodeWithPagedKVCachePyTorchWrapper") - .def(py::init()) - .def("begin_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward) - .def("end_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward) - .def("is_cuda_graph_enabled", &BatchDecodeWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled) - .def("update_page_locked_buffer_size", - &BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) - .def("forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::Forward); - py::class_( - m, "BatchPrefillWithPagedKVCachePyTorchWrapper") - .def(py::init()) - .def("begin_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward) - .def("end_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward) - .def("is_cuda_graph_enabled", &BatchPrefillWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled) - .def("update_page_locked_buffer_size", - &BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) - .def("forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::Forward) - .def("forward_custom_mask", &BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCustomMask); - py::class_( - m, "BatchPrefillWithRaggedKVCachePyTorchWrapper") - .def(py::init()) - .def("begin_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward) - .def("end_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward) - .def("is_cuda_graph_enabled", - &BatchPrefillWithRaggedKVCachePyTorchWrapper::IsCUDAGraphEnabled) - .def("update_page_locked_buffer_size", - &BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) - .def("forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward) - .def("forward_custom_mask", &BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask); py::class_(m, "CutlassSegmentGEMMPyTorchWrapper") .def(py::init()) .def("register_workspace", &CutlassSegmentGEMMPyTorchWrapper::RegisterWorkspaceBuffer) diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index 02d6a127..e3edff64 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -16,29 +16,10 @@ #pragma once #include -#include #include #include #include -torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v, - torch::Tensor tmp, unsigned int pos_encoding_mode, - unsigned int layout, int window_left, - float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta); - -std::vector single_prefill_with_kv_cache( - torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, bool causal, - unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, - int window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, - bool return_lse); - -std::vector single_prefill_with_kv_cache_custom_mask( - torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor packed_custom_mask, - torch::Tensor tmp, unsigned int layout, unsigned int pos_encoding_mode, - bool allow_fp16_qk_reduction, int window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, bool return_lse); - void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, torch::Tensor append_indptr, std::optional paged_kv_cache, std::optional paged_k_cache, @@ -106,100 +87,6 @@ torch::Tensor packbits(torch::Tensor x, const std::string& bitorder); torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, torch::Tensor output_indptr, const std::string& bitorder); -class BatchDecodeWithPagedKVCachePyTorchWrapper { - public: - void BeginForward(torch::Tensor workspace_buffer, torch::Tensor indptr, - torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads, - unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size, - unsigned int pos_encoding_mode, float logits_soft_cap, - torch::Tensor empty_q_data, torch::Tensor empty_kv_data); - void EndForward(); - void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); - bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } - std::vector Forward(torch::Tensor q, std::optional paged_kv_cache, - std::optional paged_k_cache, - std::optional paged_v_cache, - torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, - unsigned int pos_encoding_mode, int window_left, - float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta, bool return_lse); - BatchDecodeWithPagedKVCachePyTorchWrapper( - std::shared_ptr handler_ptr, flashinfer::QKVLayout kv_layout) - : handler_(handler_ptr), kv_layout_(kv_layout) {} - BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout, bool enable_cuda_graph, - unsigned int fixed_batch_size) - : kv_layout_(flashinfer::QKVLayout(layout)), - handler_(std::make_shared(enable_cuda_graph, - fixed_batch_size)) {} - - protected: - std::shared_ptr handler_; - flashinfer::QKVLayout kv_layout_; -}; - -class BatchPrefillWithPagedKVCachePyTorchWrapper { - public: - void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr, - torch::Tensor page_kv_indptr, unsigned int batch_size, - unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim, - unsigned page_size, torch::Tensor empty_q_data); - void EndForward(); - bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } - void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); - std::vector Forward(torch::Tensor q, torch::Tensor qo_indptr, - std::optional paged_kv_cache, - std::optional paged_k_cache, - std::optional paged_v_cache, - torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, bool causal, - unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, - int window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, bool return_lse); - std::vector ForwardCustomMask( - torch::Tensor q, torch::Tensor qo_indptr, std::optional paged_kv_cache, - std::optional paged_k_cache, std::optional paged_v_cache, - torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, torch::Tensor packed_custom_mask, - torch::Tensor qk_indptr, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, - int window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, - bool return_lse); - BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout, bool enable_cuda_graph) - : kv_layout_(flashinfer::QKVLayout(layout)), - handler_(std::make_shared(enable_cuda_graph)) {} - - private: - std::shared_ptr handler_; - flashinfer::QKVLayout kv_layout_; -}; - -class BatchPrefillWithRaggedKVCachePyTorchWrapper { - public: - void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr, - torch::Tensor kv_indptr, unsigned int batch_size, unsigned int num_qo_heads, - unsigned int num_kv_heads, unsigned int head_dim, torch::Tensor empty_q_data); - void EndForward(); - bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } - void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); - std::vector Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, - torch::Tensor v, torch::Tensor kv_indptr, bool causal, - unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, - int window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, bool return_lse); - std::vector ForwardCustomMask( - torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, - torch::Tensor kv_indptr, torch::Tensor packed_custom_mask, torch::Tensor qk_indptr, - unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, int window_left, - float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse); - BatchPrefillWithRaggedKVCachePyTorchWrapper(unsigned int layout, bool enable_cuda_graph) - : kv_layout_(flashinfer::QKVLayout(layout)), - handler_(std::make_shared(enable_cuda_graph)) {} - - private: - std::shared_ptr handler_; - flashinfer::QKVLayout kv_layout_; -}; - class CutlassSegmentGEMMPyTorchWrapper { public: void RegisterWorkspaceBuffer(torch::Tensor workspace_buffer); diff --git a/python/csrc/flashinfer_ops_decode.cu b/python/csrc/flashinfer_ops_decode.cu new file mode 100644 index 00000000..15e3f25a --- /dev/null +++ b/python/csrc/flashinfer_ops_decode.cu @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "flashinfer_ops_decode.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache, + "Single-request decode with KV-Cache operator"); + py::class_(m, + "BatchDecodeWithPagedKVCachePyTorchWrapper") + .def(py::init()) + .def("begin_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward) + .def("end_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward) + .def("is_cuda_graph_enabled", &BatchDecodeWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled) + .def("update_page_locked_buffer_size", + &BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) + .def("forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::Forward); +} diff --git a/python/csrc/flashinfer_ops_decode.h b/python/csrc/flashinfer_ops_decode.h new file mode 100644 index 00000000..1f955a7f --- /dev/null +++ b/python/csrc/flashinfer_ops_decode.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include + +#include +#include +#include + +torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v, + torch::Tensor tmp, unsigned int pos_encoding_mode, + unsigned int layout, int window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta); + +class BatchDecodeWithPagedKVCachePyTorchWrapper { + public: + void BeginForward(torch::Tensor workspace_buffer, torch::Tensor indptr, + torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads, + unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size, + unsigned int pos_encoding_mode, float logits_soft_cap, + torch::Tensor empty_q_data, torch::Tensor empty_kv_data); + void EndForward(); + void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); + bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } + std::vector Forward(torch::Tensor q, std::optional paged_kv_cache, + std::optional paged_k_cache, + std::optional paged_v_cache, + torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, + torch::Tensor paged_kv_last_page_len, + unsigned int pos_encoding_mode, int window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta, bool return_lse); + BatchDecodeWithPagedKVCachePyTorchWrapper( + std::shared_ptr handler_ptr, flashinfer::QKVLayout kv_layout) + : handler_(handler_ptr), kv_layout_(kv_layout) {} + BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout, bool enable_cuda_graph, + unsigned int fixed_batch_size) + : kv_layout_(flashinfer::QKVLayout(layout)), + handler_(std::make_shared(enable_cuda_graph, + fixed_batch_size)) {} + + protected: + std::shared_ptr handler_; + flashinfer::QKVLayout kv_layout_; +}; diff --git a/python/csrc/flashinfer_ops_prefill.cu b/python/csrc/flashinfer_ops_prefill.cu new file mode 100644 index 00000000..992cf10f --- /dev/null +++ b/python/csrc/flashinfer_ops_prefill.cu @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "flashinfer_ops_prefill.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache, + "Single-request prefill with KV-Cache operator, return logsumexp"); + m.def( + "single_prefill_with_kv_cache_custom_mask", &single_prefill_with_kv_cache_custom_mask, + "Single-request prefill with KV-Cache operator, user defined custom mask, return logsumexp"); + py::class_( + m, "BatchPrefillWithPagedKVCachePyTorchWrapper") + .def(py::init()) + .def("begin_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward) + .def("end_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward) + .def("is_cuda_graph_enabled", &BatchPrefillWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled) + .def("update_page_locked_buffer_size", + &BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) + .def("forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::Forward) + .def("forward_custom_mask", &BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCustomMask); + py::class_( + m, "BatchPrefillWithRaggedKVCachePyTorchWrapper") + .def(py::init()) + .def("begin_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward) + .def("end_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward) + .def("is_cuda_graph_enabled", + &BatchPrefillWithRaggedKVCachePyTorchWrapper::IsCUDAGraphEnabled) + .def("update_page_locked_buffer_size", + &BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) + .def("forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward) + .def("forward_custom_mask", &BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask); +} diff --git a/python/csrc/flashinfer_ops_prefill.h b/python/csrc/flashinfer_ops_prefill.h new file mode 100644 index 00000000..949da9ae --- /dev/null +++ b/python/csrc/flashinfer_ops_prefill.h @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include + +#include +#include +#include + +std::vector single_prefill_with_kv_cache( + torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, bool causal, + unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, + int window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + bool return_lse); + +std::vector single_prefill_with_kv_cache_custom_mask( + torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor packed_custom_mask, + torch::Tensor tmp, unsigned int layout, unsigned int pos_encoding_mode, + bool allow_fp16_qk_reduction, int window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, bool return_lse); + +class BatchPrefillWithPagedKVCachePyTorchWrapper { + public: + void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr, + torch::Tensor page_kv_indptr, unsigned int batch_size, + unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim, + unsigned page_size, torch::Tensor empty_q_data); + void EndForward(); + bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } + void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); + std::vector Forward(torch::Tensor q, torch::Tensor qo_indptr, + std::optional paged_kv_cache, + std::optional paged_k_cache, + std::optional paged_v_cache, + torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, + torch::Tensor paged_kv_last_page_len, bool causal, + unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, + int window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, bool return_lse); + std::vector ForwardCustomMask( + torch::Tensor q, torch::Tensor qo_indptr, std::optional paged_kv_cache, + std::optional paged_k_cache, std::optional paged_v_cache, + torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, + torch::Tensor paged_kv_last_page_len, torch::Tensor packed_custom_mask, + torch::Tensor qk_indptr, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, + int window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + bool return_lse); + BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout, bool enable_cuda_graph) + : kv_layout_(flashinfer::QKVLayout(layout)), + handler_(std::make_shared(enable_cuda_graph)) {} + + private: + std::shared_ptr handler_; + flashinfer::QKVLayout kv_layout_; +}; + +class BatchPrefillWithRaggedKVCachePyTorchWrapper { + public: + void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr, + torch::Tensor kv_indptr, unsigned int batch_size, unsigned int num_qo_heads, + unsigned int num_kv_heads, unsigned int head_dim, torch::Tensor empty_q_data); + void EndForward(); + bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } + void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); + std::vector Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, + torch::Tensor v, torch::Tensor kv_indptr, bool causal, + unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, + int window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, bool return_lse); + std::vector ForwardCustomMask( + torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, + torch::Tensor kv_indptr, torch::Tensor packed_custom_mask, torch::Tensor qk_indptr, + unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, int window_left, + float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse); + BatchPrefillWithRaggedKVCachePyTorchWrapper(unsigned int layout, bool enable_cuda_graph) + : kv_layout_(flashinfer::QKVLayout(layout)), + handler_(std::make_shared(enable_cuda_graph)) {} + + private: + std::shared_ptr handler_; + flashinfer::QKVLayout kv_layout_; +}; diff --git a/python/csrc/single_decode.cu b/python/csrc/single_decode.cu index 10013f9c..abbe81dc 100644 --- a/python/csrc/single_decode.cu +++ b/python/csrc/single_decode.cu @@ -15,7 +15,7 @@ */ #include -#include "flashinfer_ops.h" +#include "flashinfer_ops_decode.h" #include "pytorch_extension_utils.h" using namespace flashinfer; diff --git a/python/csrc/single_prefill.cu b/python/csrc/single_prefill.cu index 5a38bb6e..320d2c35 100644 --- a/python/csrc/single_prefill.cu +++ b/python/csrc/single_prefill.cu @@ -15,7 +15,7 @@ */ #include -#include "flashinfer_ops.h" +#include "flashinfer_ops_prefill.h" #include "pytorch_extension_utils.h" using namespace flashinfer; diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py index 504d82e4..0d5f2bb8 100644 --- a/python/flashinfer/decode.py +++ b/python/flashinfer/decode.py @@ -20,13 +20,15 @@ # mypy: disable-error-code="attr-defined" try: - from . import _kernels + from . import _decode + from . import _prefill except ImportError as e: import os import logging if os.environ.get("BUILD_DOC", "0") == "1": - _kernels = None + _decode = None + _prefill = None logging.warning("Kernels are not loaded in documentation build mode.") else: raise e @@ -172,7 +174,7 @@ def single_decode_with_kv_cache( ) if use_tensor_cores: - out = _kernels.single_prefill_with_kv_cache( + out = _prefill.single_prefill_with_kv_cache( q.unsqueeze(0), k, v, @@ -189,7 +191,7 @@ def single_decode_with_kv_cache( False, # return_lse )[0].squeeze(0) else: - out = _kernels.single_decode_with_kv_cache( + out = _decode.single_decode_with_kv_cache( q, k, v, @@ -353,7 +355,7 @@ def __init__( if use_tensor_cores: self._use_tensor_cores = True - self._wrapper = _kernels.BatchPrefillWithPagedKVCachePyTorchWrapper( + self._wrapper = _prefill.BatchPrefillWithPagedKVCachePyTorchWrapper( TensorLayout[kv_layout].value, use_cuda_graph, ) @@ -365,7 +367,7 @@ def __init__( ) else: self._use_tensor_cores = False - self._wrapper = _kernels.BatchDecodeWithPagedKVCachePyTorchWrapper( + self._wrapper = _decode.BatchDecodeWithPagedKVCachePyTorchWrapper( TensorLayout[kv_layout].value, use_cuda_graph, self._fixed_batch_size, diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index 9a149a42..74512d2b 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -21,13 +21,13 @@ # mypy: disable-error-code="attr-defined" try: - from . import _kernels + from . import _prefill except ImportError as e: import os import logging if os.environ.get("BUILD_DOC", "0") == "1": - _kernels = None + _prefill = None logging.warning("Kernels are not loaded in documentation build mode.") else: raise e @@ -187,7 +187,7 @@ def single_prefill_with_kv_cache( custom_mask.contiguous().view(-1), bitorder="little" ) if packed_custom_mask is not None: - return _kernels.single_prefill_with_kv_cache_custom_mask( + return _prefill.single_prefill_with_kv_cache_custom_mask( q, k, v, @@ -204,7 +204,7 @@ def single_prefill_with_kv_cache( False, # return lse )[0] else: - return _kernels.single_prefill_with_kv_cache( + return _prefill.single_prefill_with_kv_cache( q, k, v, @@ -372,7 +372,7 @@ def single_prefill_with_kv_cache_return_lse( custom_mask.contiguous().view(-1), bitorder="little" ) if packed_custom_mask is not None: - return _kernels.single_prefill_with_kv_cache_custom_mask( + return _prefill.single_prefill_with_kv_cache_custom_mask( q, k, v, @@ -389,7 +389,7 @@ def single_prefill_with_kv_cache_return_lse( True, # return lse ) else: - return _kernels.single_prefill_with_kv_cache( + return _prefill.single_prefill_with_kv_cache( q, k, v, @@ -604,7 +604,7 @@ def __init__( _check_kv_layout(kv_layout) self._kv_layout = kv_layout self._workspace_buffer = workspace_buffer - self._wrapper = _kernels.BatchPrefillWithPagedKVCachePyTorchWrapper( + self._wrapper = _prefill.BatchPrefillWithPagedKVCachePyTorchWrapper( TensorLayout[kv_layout].value, use_cuda_graph, ) @@ -1225,7 +1225,7 @@ def __init__( _check_kv_layout(kv_layout) self._kv_layout = kv_layout self._workspace_buffer = workspace_buffer - self._wrapper = _kernels.BatchPrefillWithRaggedKVCachePyTorchWrapper( + self._wrapper = _prefill.BatchPrefillWithRaggedKVCachePyTorchWrapper( TensorLayout[kv_layout].value, use_cuda_graph, ) diff --git a/python/setup.py b/python/setup.py index 56fe98b8..86a18c18 100644 --- a/python/setup.py +++ b/python/setup.py @@ -59,7 +59,7 @@ def write_if_different(path: pathlib.Path, content: str) -> None: f.write(content) -def get_instantiation_cu() -> List[str]: +def get_instantiation_cu() -> Tuple[List[str], List[str]]: prefix = "csrc/generated" (root / prefix).mkdir(parents=True, exist_ok=True) @@ -99,7 +99,8 @@ def get_instantiation_cu() -> List[str]: if enable_fp8: decode_dtypes.extend(fp8_dtypes) - files = [] + files_decode = [] + files_prefill = [] # single decode files for ( head_dim, @@ -115,7 +116,7 @@ def get_instantiation_cu() -> List[str]: ): dtype_out = dtype_q fname = f"single_decode_head_{head_dim}_logitshook_{logits_hook}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu" - files.append(prefix + "/" + fname) + files_decode.append(prefix + "/" + fname) content = generate_single_decode_inst.get_cu_file_str( head_dim, logits_hook, @@ -142,7 +143,7 @@ def get_instantiation_cu() -> List[str]: ): dtype_out = dtype_q fname = f"batch_paged_decode_head_{head_dim}_logitshook_{logits_hook}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}_idtype_{idtype}.cu" - files.append(prefix + "/" + fname) + files_decode.append(prefix + "/" + fname) content = generate_batch_paged_decode_inst.get_cu_file_str( head_dim, logits_hook, @@ -170,7 +171,7 @@ def get_instantiation_cu() -> List[str]: ): for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)): fname = f"single_prefill_head_{head_dim}_logitshook_{logits_hook}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}.cu" - files.append(prefix + "/" + fname) + files_prefill.append(prefix + "/" + fname) content = generate_single_prefill_inst.get_cu_file_str( head_dim, logits_hook, @@ -203,7 +204,7 @@ def get_instantiation_cu() -> List[str]: itertools.product(prefill_dtypes, fp8_dtypes) ): fname = f"batch_paged_prefill_head_{head_dim}_logitshook_{logits_hook}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu" - files.append(prefix + "/" + fname) + files_prefill.append(prefix + "/" + fname) content = generate_batch_paged_prefill_inst.get_cu_file_str( head_dim, logits_hook, @@ -235,7 +236,7 @@ def get_instantiation_cu() -> List[str]: ): for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)): fname = f"batch_ragged_prefill_head_{head_dim}_logitshook_{logits_hook}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu" - files.append(prefix + "/" + fname) + files_prefill.append(prefix + "/" + fname) content = generate_batch_ragged_prefill_inst.get_cu_file_str( head_dim, logits_hook, @@ -249,7 +250,7 @@ def get_instantiation_cu() -> List[str]: ) write_if_different(root / prefix / fname, content) - return files + return files_prefill, files_decode def get_version(): @@ -309,48 +310,71 @@ def __init__(self, *args, **kwargs) -> None: if __name__ == "__main__": remove_unwanted_pytorch_nvcc_flags() generate_build_meta() + files_prefill, files_decode = get_instantiation_cu() + include_dirs = [ + str(root.resolve() / "include"), + str( + root.resolve() / "3rdparty" / "cutlass" / "include" + ), # for group gemm + ] + extra_compile_args = { + "cxx": [ + "-O3", + "-Wno-switch-bool", + ], + "nvcc": [ + "-O3", + "-std=c++17", + "--threads", + "1", + "-Xfatbin", + "-compress-all", + ], + } ext_modules = [] ext_modules.append( torch_cpp_ext.CUDAExtension( name="flashinfer._kernels", sources=[ - "csrc/single_decode.cu", - "csrc/single_prefill.cu", "csrc/cascade.cu", "csrc/page.cu", - "csrc/batch_decode.cu", "csrc/flashinfer_ops.cu", - "csrc/batch_prefill.cu", "csrc/sampling.cu", "csrc/norm.cu", "csrc/rope.cu", "csrc/group_gemm.cu", "csrc/quantization.cu", - ] - + get_instantiation_cu(), - include_dirs=[ - str(root.resolve() / "include"), - str( - root.resolve() / "3rdparty" / "cutlass" / "include" - ), # for group gemm ], - extra_compile_args={ - "cxx": [ - "-O3", - "-Wno-switch-bool", - ], - "nvcc": [ - "-O3", - "-std=c++17", - "--threads", - "1", - "-Xfatbin", - "-compress-all", - ], - }, + include_dirs=include_dirs, + extra_compile_args=extra_compile_args, + ) + ) + ext_modules.append( + torch_cpp_ext.CUDAExtension( + name="flashinfer._decode", + sources=[ + "csrc/single_decode.cu", + "csrc/flashinfer_ops_decode.cu", + "csrc/batch_decode.cu", + ] + + files_decode, + include_dirs=include_dirs, + extra_compile_args=extra_compile_args, + ) + ) + ext_modules.append( + torch_cpp_ext.CUDAExtension( + name="flashinfer._prefill", + sources=[ + "csrc/single_prefill.cu", + "csrc/flashinfer_ops_prefill.cu", + "csrc/batch_prefill.cu", + ] + + files_prefill, + include_dirs=include_dirs, + extra_compile_args=extra_compile_args, ) ) - setuptools.setup( name="flashinfer", version=get_version(),