[Kernel] optimize moe_align_block_size for cuda graph and large num_experts (e.g. DeepSeek-V3) #12222

Merged on Jan 21, 2025 (7 commits)

Changes from 5 commits
92 changes: 56 additions & 36 deletions csrc/moe/moe_align_sum_kernels.cu
@@ -21,7 +21,7 @@ __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
}
} // namespace

template <typename scalar_t>
template <typename scalar_t, typename token_cnts_t>
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
int32_t* sorted_token_ids,
int32_t* expert_ids,
@@ -32,12 +32,8 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
const size_t start_idx = threadIdx.x * tokens_per_thread;

extern __shared__ int32_t shared_mem[];

int32_t* tokens_cnts =
shared_mem; // 2d tensor with shape (blockDim.x + 1, num_experts)
int32_t* cumsum =
shared_mem +
(blockDim.x + 1) * num_experts; // 1d tensor with shape (num_experts + 1)
int32_t* cumsum = shared_mem; // 1d tensor with shape (num_experts + 1)
token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + blockDim.x + 1);
Contributor (review comment): nit: Add the original comments back?


for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
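The new carving of dynamic shared memory above is small but central to the change: the (num_experts + 1)-entry `cumsum` array stays `int32_t` and is placed first, while the per-thread `tokens_cnts` table follows it and is now templated as `token_cnts_t`, so the host side can shrink it to `uint16_t`. A minimal sketch of that layout, not part of the diff (the helper name is invented here, and the alignment rationale is an inference from the ordering):

// Sketch only: how the templated kernel carves its dynamic shared memory.
// Assumes blockDim.x == num_thread >= num_experts, as launched by the host code.
template <typename token_cnts_t>
__device__ void carve_shared_mem(int32_t* shared_mem, int num_experts,
                                 int32_t*& cumsum,
                                 token_cnts_t*& tokens_cnts) {
  // cumsum: 1d, (num_experts + 1) int32 prefix sums. Keeping the 4-byte array
  // first means it stays naturally aligned even when token_cnts_t is uint16_t.
  cumsum = shared_mem;
  // tokens_cnts: 2d, (blockDim.x + 1) x num_experts counters, possibly 16-bit,
  // starting right after the int32 words reserved for cumsum.
  tokens_cnts = reinterpret_cast<token_cnts_t*>(shared_mem + blockDim.x + 1);
}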
@@ -74,7 +70,7 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
block_size) *
block_size;
}
*total_tokens_post_pad = cumsum[num_experts];
*total_tokens_post_pad = (int32_t)cumsum[num_experts];
mgoin marked this conversation as resolved.
}

__syncthreads();
Expand Down Expand Up @@ -224,26 +220,43 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch::Tensor num_tokens_post_pad) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

// If we have very large number of experts, we can no longer use shared
// memory.
// TODO(simon): the right solution should be calculating the exact right
// amount of shared memory and use that. The num_experts >= 256 is just a
// temporary solution to unblock Deepseek V3.
if (num_experts >= 256) {
int device_max_shared_mem;
auto dev = topk_ids.get_device();
cudaDeviceGetAttribute(&device_max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);

const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
const int32_t shared_mem_i32 =
((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
const int32_t shared_mem_i16 =
((num_thread + 1) * num_experts) * sizeof(uint16_t) +
(num_experts + 1) * sizeof(int32_t);

bool use_global_memory = false, use_i16 = false;
mgoin marked this conversation as resolved.
if (shared_mem_i16 > device_max_shared_mem) {
use_global_memory = true;
} else if (shared_mem_i32 > device_max_shared_mem &&
topk_ids.numel() <= 65535) {
// When topk_ids.numel() is smaller than 65535 (the max value of uint16),
// every element of token_cnts is also smaller than 65535, so we can use
// uint16 as the dtype of token_cnts.
use_i16 = true;
}

if (use_global_memory) {
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);

const int32_t mem_tokens_cnts =
((num_experts + 1) * num_experts) * sizeof(int32_t);
const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
// allocate global memory
int32_t* tokens_cnts;
int32_t* cumsum;
cudaMalloc(&tokens_cnts, mem_tokens_cnts);
cudaMalloc(&cumsum, mem_cumsum);
auto options_int = torch::TensorOptions()
Contributor (review comment): Much better than directly using cudaMalloc.

.dtype(torch::kInt)
.device(topk_ids.device());
torch::Tensor token_cnts_buffer =
torch::empty({(num_experts + 1) * num_experts}, options_int);
torch::Tensor cumsum_buffer =
torch::empty({num_experts + 1}, options_int);

auto kernel =
vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
Expand All @@ -252,25 +265,32 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel(), tokens_cnts, cumsum);
cudaFree(tokens_cnts);
cudaFree(cumsum);
topk_ids.numel(), token_cnts_buffer.data_ptr<int32_t>(),
cumsum_buffer.data_ptr<int32_t>());
});
} else {
} else if (use_i16) {
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
const int32_t shared_mem =
((num_thread + 1) * num_experts + (num_experts + 1)) *
sizeof(int32_t);

// set dynamic shared mem
auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
auto kernel =
vllm::moe::moe_align_block_size_kernel<scalar_t, uint16_t>;
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
(void*)kernel, shared_mem_i16));
kernel<<<1, num_thread, shared_mem_i16, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel());
});
} else {
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
auto kernel =
vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
(void*)kernel, shared_mem));
kernel<<<1, num_thread, shared_mem, stream>>>(
(void*)kernel, shared_mem_i32));
kernel<<<1, num_thread, shared_mem_i32, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
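To make the new size thresholds concrete, here is the arithmetic for the DeepSeek-V3 case (num_experts = 256, hence num_thread = 256). This is an editorial illustration, not part of the patch; the "roughly 164 KB" figure is the approximate opt-in shared-memory limit of an A100-class GPU and differs on other hardware.

// Illustrative only: reproduces the shared_mem_i32 / shared_mem_i16 formulas
// from the diff above for num_experts = 256 (DeepSeek-V3).
#include <cstdint>
#include <cstdio>

int main() {
  const int32_t num_experts = 256;
  const int32_t num_thread = 256;  // max(num_experts, WARP_SIZE)

  const int32_t shared_mem_i32 =
      ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
  const int32_t shared_mem_i16 =
      ((num_thread + 1) * num_experts) * sizeof(uint16_t) +
      (num_experts + 1) * sizeof(int32_t);

  // shared_mem_i32 = (257 * 256 + 257) * 4 = 264,196 bytes (~258 KiB)
  // shared_mem_i16 = 257 * 256 * 2 + 257 * 4 = 132,612 bytes (~130 KiB)
  // With int32 counters the kernel exceeds a roughly 164 KB opt-in limit and
  // must fall back to global memory; with uint16 counters it still fits in
  // shared memory, as long as topk_ids.numel() <= 65535 so no counter overflows.
  printf("i32: %d bytes, i16: %d bytes\n", shared_mem_i32, shared_mem_i16);
  return 0;
}

The switch from raw cudaMalloc/cudaFree to torch::empty buffers in the global-memory fallback is what the reviewer's comment above refers to: allocating through PyTorch's caching allocator avoids calling cudaMalloc/cudaFree on every invocation, which is presumably what makes the op usable under CUDA graph capture and motivates the config.py change below.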
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -607,7 +607,7 @@ def _verify_cuda_graph(self) -> None:
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
self.max_model_len)

MODEL_NOT_SUPPORT_CUDA_GRAPH = ['deepseek_v3', 'mllama']
MODEL_NOT_SUPPORT_CUDA_GRAPH = ['mllama']
if (self.hf_config.model_type in MODEL_NOT_SUPPORT_CUDA_GRAPH
and not self.enforce_eager):
logger.warning(