[Kernel] optimize moe_align_block_size for cuda graph and large num_experts (e.g. DeepSeek-V3) #12222
Merged
simon-mo merged 7 commits into vllm-project:main from jinzhen-lin:optimize_moe_align_block_size on Jan 21, 2025
Changes from all commits (7 commits):
07fa60b  optimize moe_align_block_size  (jinzhen-lin)
3201616  update config  (jinzhen-lin)
57abfd4  fix error  (jinzhen-lin)
4d263c0  fix format error  (jinzhen-lin)
0ea4deb  Format  (mgoin)
e72c81f  Update csrc/moe/moe_align_sum_kernels.cu  (mgoin)
49fb023  Update csrc/moe/moe_align_sum_kernels.cu  (mgoin)
csrc/moe/moe_align_sum_kernels.cu
@@ -21,7 +21,7 @@ __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
   }
 }  // namespace

-template <typename scalar_t>
+template <typename scalar_t, typename token_cnts_t>
 __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
                                             int32_t* sorted_token_ids,
                                             int32_t* expert_ids,
@@ -32,12 +32,8 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
   const size_t start_idx = threadIdx.x * tokens_per_thread;

   extern __shared__ int32_t shared_mem[];
-
-  int32_t* tokens_cnts =
-      shared_mem;  // 2d tensor with shape (blockDim.x + 1, num_experts)
-  int32_t* cumsum =
-      shared_mem +
-      (blockDim.x + 1) * num_experts;  // 1d tensor with shape (num_experts + 1)
+  int32_t* cumsum = shared_mem;  // 1d tensor with shape (num_experts + 1)
+  token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + blockDim.x + 1);

   for (int i = 0; i < num_experts; ++i) {
     tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
@@ -74,7 +70,7 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
                        block_size) *
                    block_size;
     }
-    *total_tokens_post_pad = cumsum[num_experts];
+    *total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
   }

   __syncthreads();
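Taken together, the kernel changes above rearrange the dynamic shared-memory layout: cumsum (always int32, length num_experts + 1) now sits at the start, and the (num_thread + 1) x num_experts token-count table follows it in the new token_cnts_t template type, so it can be stored as uint16_t instead of int32_t. Below is a rough host-side sketch of what that buys; it is not part of the PR, and the helper name and the 256-expert figures are only illustrative.

#include <cstdint>
#include <cstdio>

// Hypothetical helper (not in vLLM): dynamic shared memory needed for the new
// layout, i.e. an int32 cumsum of length (num_experts + 1) followed by a
// (num_thread + 1) x num_experts token-count table stored as token_cnts_t.
template <typename token_cnts_t>
size_t shared_mem_bytes(int64_t num_experts, int64_t num_thread) {
  size_t cumsum_bytes = (num_experts + 1) * sizeof(int32_t);
  size_t counts_bytes = (num_thread + 1) * num_experts * sizeof(token_cnts_t);
  return cumsum_bytes + counts_bytes;
}

int main() {
  // DeepSeek-V3-style configuration: 256 experts, one thread per expert.
  printf("int32 counts:  %zu bytes\n", shared_mem_bytes<int32_t>(256, 256));
  printf("uint16 counts: %zu bytes\n", shared_mem_bytes<uint16_t>(256, 256));
  // Prints roughly 264196 vs 132612 bytes (~258 KiB vs ~130 KiB), which is
  // why uint16 counts let 256 experts stay within typical per-block
  // shared-memory limits while int32 counts do not.
  return 0;
}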
@@ -224,26 +220,44 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                           torch::Tensor num_tokens_post_pad) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-  // If we have very large number of experts, we can no longer use shared
-  // memory.
-  // TODO(simon): the right solution should be calculating the exact right
-  // amount of shared memory and use that. The num_experts >= 256 is just a
-  // temporary solution to unblock Deepseek V3.
-  if (num_experts >= 256) {
+  int device_max_shared_mem;
+  auto dev = topk_ids.get_device();
+  cudaDeviceGetAttribute(&device_max_shared_mem,
+                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
+
+  const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
+  const int32_t shared_mem_i32 =
+      ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
+  const int32_t shared_mem_i16 =
+      ((num_thread + 1) * num_experts) * sizeof(uint16_t) +
+      (num_experts + 1) * sizeof(int32_t);
+
+  bool use_global_memory = false;
+  bool use_i16 = false;  // Use uint16_t for shared memory token counts
+  if (shared_mem_i16 > device_max_shared_mem) {
+    use_global_memory = true;
+  } else if (shared_mem_i32 > device_max_shared_mem &&
+             topk_ids.numel() <= 65535) {
+    // when nelements of topk_ids is smaller than 65535 (max value of uint16),
+    // element value of token_cnts would also smaller than 65535,
+    // so we can use uint16 as dtype of token_cnts
+    use_i16 = true;
+  }
+
+  if (use_global_memory) {
     VLLM_DISPATCH_INTEGRAL_TYPES(
         topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
           // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
           // tensors
           const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);

-          const int32_t mem_tokens_cnts =
-              ((num_experts + 1) * num_experts) * sizeof(int32_t);
-          const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
-          // allocate global memory
-          int32_t* tokens_cnts;
-          int32_t* cumsum;
-          cudaMalloc(&tokens_cnts, mem_tokens_cnts);
-          cudaMalloc(&cumsum, mem_cumsum);
+          auto options_int = torch::TensorOptions()
+                                 .dtype(torch::kInt)
+                                 .device(topk_ids.device());
+          torch::Tensor token_cnts_buffer =
+              torch::empty({(num_experts + 1) * num_experts}, options_int);
+          torch::Tensor cumsum_buffer =
+              torch::empty({num_experts + 1}, options_int);

           auto kernel =
               vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;

Review comment on the torch::TensorOptions allocation above: much better than directly using cudaMalloc.
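The other half of the change is visible in this hunk: the global-memory fallback no longer calls cudaMalloc/cudaFree on every invocation but allocates its workspace as ordinary torch tensors. That keeps the allocation inside PyTorch's CUDA caching allocator, which matters for the CUDA-graph goal in the PR title, since a raw cudaMalloc is synchronous and generally cannot be issued while a stream is being captured into a graph. A minimal sketch of the pattern follows; it is not the vLLM implementation, and launch_align_kernel is a hypothetical stand-in for the real kernel launch.

#include <torch/torch.h>

// Sketch only: workspace tensors come from PyTorch's caching allocator, so no
// raw cudaMalloc/cudaFree happens on the hot path (or during CUDA graph
// capture), and the memory is handed back automatically when the tensors go
// out of scope.
void moe_align_with_torch_workspace(const torch::Tensor& topk_ids,
                                    int64_t num_experts) {
  auto options_int =
      torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
  torch::Tensor token_cnts_buffer =
      torch::empty({(num_experts + 1) * num_experts}, options_int);
  torch::Tensor cumsum_buffer = torch::empty({num_experts + 1}, options_int);

  // Hypothetical kernel launch using the workspace:
  // launch_align_kernel(topk_ids, token_cnts_buffer.data_ptr<int32_t>(),
  //                     cumsum_buffer.data_ptr<int32_t>());
}  // buffers released here, back into the caching allocator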
@@ -252,25 +266,32 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
               sorted_token_ids.data_ptr<int32_t>(),
               experts_ids.data_ptr<int32_t>(),
               num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
-              topk_ids.numel(), tokens_cnts, cumsum);
-          cudaFree(tokens_cnts);
-          cudaFree(cumsum);
+              topk_ids.numel(), token_cnts_buffer.data_ptr<int32_t>(),
+              cumsum_buffer.data_ptr<int32_t>());
         });
-  } else {
+  } else if (use_i16) {
     VLLM_DISPATCH_INTEGRAL_TYPES(
         topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
-          // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
-          // tensors
-          const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
-          const int32_t shared_mem =
-              ((num_thread + 1) * num_experts + (num_experts + 1)) *
-              sizeof(int32_t);
-
           // set dynamic shared mem
-          auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
+          auto kernel =
+              vllm::moe::moe_align_block_size_kernel<scalar_t, uint16_t>;
           AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
-              (void*)kernel, shared_mem));
-          kernel<<<1, num_thread, shared_mem, stream>>>(
+              (void*)kernel, shared_mem_i16));
+          kernel<<<1, num_thread, shared_mem_i16, stream>>>(
+              topk_ids.data_ptr<scalar_t>(),
+              sorted_token_ids.data_ptr<int32_t>(),
+              experts_ids.data_ptr<int32_t>(),
+              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
+              topk_ids.numel());
+        });
+  } else {
+    VLLM_DISPATCH_INTEGRAL_TYPES(
+        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
+          auto kernel =
+              vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
+          AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
+              (void*)kernel, shared_mem_i32));
+          kernel<<<1, num_thread, shared_mem_i32, stream>>>(
               topk_ids.data_ptr<scalar_t>(),
               sorted_token_ids.data_ptr<int32_t>(),
               experts_ids.data_ptr<int32_t>(),
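Putting the three branches together, the host wrapper now queries the device's opt-in shared-memory limit and dispatches to the cheapest kernel that fits. Below is a standalone sketch of that selection logic, with the DeepSeek-V3 case of 256 experts worked through in the comments; it is not vLLM code, the 164 KiB limit is an assumed device figure, and WARP_SIZE is taken as 32.

#include <cstdint>
#include <cstdio>

// Illustrative only: mirrors the use_global_memory / use_i16 selection above.
enum class Path { kSharedInt32 = 0, kSharedUint16 = 1, kGlobalMemory = 2 };

Path pick_path(int64_t num_experts, int64_t topk_numel,
               int64_t device_max_shared_mem, int64_t warp_size = 32) {
  const int64_t num_thread = num_experts > warp_size ? num_experts : warp_size;
  const int64_t shared_mem_i32 =
      ((num_thread + 1) * num_experts + (num_experts + 1)) * 4;  // int32 counts
  const int64_t shared_mem_i16 =
      (num_thread + 1) * num_experts * 2 + (num_experts + 1) * 4;  // u16 counts

  bool use_global_memory = false;
  bool use_i16 = false;
  if (shared_mem_i16 > device_max_shared_mem) {
    use_global_memory = true;  // even uint16 counts do not fit in shared mem
  } else if (shared_mem_i32 > device_max_shared_mem && topk_numel <= 65535) {
    use_i16 = true;  // counts can never exceed numel, so uint16 is safe
  }
  if (use_global_memory) return Path::kGlobalMemory;
  return use_i16 ? Path::kSharedUint16 : Path::kSharedInt32;
}

int main() {
  // 256 experts with an assumed 164 KiB opt-in limit: the int32 layout needs
  // ~258 KiB and does not fit, the uint16 layout needs ~130 KiB and does, so
  // the uint16 kernel is selected whenever topk_ids has <= 65535 elements.
  Path p = pick_path(/*num_experts=*/256, /*topk_numel=*/4096,
                     /*device_max_shared_mem=*/164 * 1024);
  printf("selected path = %d\n", static_cast<int>(p));  // 1 == kSharedUint16
  return 0;
}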
Review comment: nit: Add the original comments back?