diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index 24341d63fb1f8..715a1b42841f2 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -21,7 +21,7 @@ __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, } } // namespace -template +template __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, int32_t* expert_ids, @@ -32,12 +32,8 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, const size_t start_idx = threadIdx.x * tokens_per_thread; extern __shared__ int32_t shared_mem[]; - - int32_t* tokens_cnts = - shared_mem; // 2d tensor with shape (blockDim.x + 1, num_experts) - int32_t* cumsum = - shared_mem + - (blockDim.x + 1) * num_experts; // 1d tensor with shape (num_experts + 1) + int32_t* cumsum = shared_mem; // 1d tensor with shape (num_experts + 1) + token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + blockDim.x + 1); for (int i = 0; i < num_experts; ++i) { tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; @@ -74,7 +70,7 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, block_size) * block_size; } - *total_tokens_post_pad = cumsum[num_experts]; + *total_tokens_post_pad = static_cast(cumsum[num_experts]); } __syncthreads(); @@ -224,26 +220,44 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, torch::Tensor num_tokens_post_pad) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - // If we have very large number of experts, we can no longer use shared - // memory. - // TODO(simon): the right solution should be calculating the exact right - // amount of shared memory and use that. The num_experts >= 256 is just a - // temporary solution to unblock Deepseek V3. - if (num_experts >= 256) { + int device_max_shared_mem; + auto dev = topk_ids.get_device(); + cudaDeviceGetAttribute(&device_max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + + const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); + const int32_t shared_mem_i32 = + ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); + const int32_t shared_mem_i16 = + ((num_thread + 1) * num_experts) * sizeof(uint16_t) + + (num_experts + 1) * sizeof(int32_t); + + bool use_global_memory = false; + bool use_i16 = false; // Use uint16_t for shared memory token counts + if (shared_mem_i16 > device_max_shared_mem) { + use_global_memory = true; + } else if (shared_mem_i32 > device_max_shared_mem && + topk_ids.numel() <= 65535) { + // when nelements of topk_ids is smaller than 65535 (max value of uint16), + // element value of token_cnts would also smaller than 65535, + // so we can use uint16 as dtype of token_cnts + use_i16 = true; + } + + if (use_global_memory) { VLLM_DISPATCH_INTEGRAL_TYPES( topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] { // calc needed amount of shared mem for `tokens_cnts` and `cumsum` // tensors const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); - const int32_t mem_tokens_cnts = - ((num_experts + 1) * num_experts) * sizeof(int32_t); - const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t); - // allocate global memory - int32_t* tokens_cnts; - int32_t* cumsum; - cudaMalloc(&tokens_cnts, mem_tokens_cnts); - cudaMalloc(&cumsum, mem_cumsum); + auto options_int = torch::TensorOptions() + .dtype(torch::kInt) + .device(topk_ids.device()); + torch::Tensor token_cnts_buffer = + torch::empty({(num_experts + 1) * num_experts}, options_int); + torch::Tensor cumsum_buffer = + torch::empty({num_experts + 1}, options_int); auto kernel = vllm::moe::moe_align_block_size_global_mem_kernel; @@ -252,25 +266,32 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, sorted_token_ids.data_ptr(), experts_ids.data_ptr(), num_tokens_post_pad.data_ptr(), num_experts, block_size, - topk_ids.numel(), tokens_cnts, cumsum); - cudaFree(tokens_cnts); - cudaFree(cumsum); + topk_ids.numel(), token_cnts_buffer.data_ptr(), + cumsum_buffer.data_ptr()); }); - } else { + } else if (use_i16) { VLLM_DISPATCH_INTEGRAL_TYPES( topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - // calc needed amount of shared mem for `tokens_cnts` and `cumsum` - // tensors - const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); - const int32_t shared_mem = - ((num_thread + 1) * num_experts + (num_experts + 1)) * - sizeof(int32_t); - // set dynamic shared mem - auto kernel = vllm::moe::moe_align_block_size_kernel; + auto kernel = + vllm::moe::moe_align_block_size_kernel; + AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( + (void*)kernel, shared_mem_i16)); + kernel<<<1, num_thread, shared_mem_i16, stream>>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), num_experts, block_size, + topk_ids.numel()); + }); + } else { + VLLM_DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + auto kernel = + vllm::moe::moe_align_block_size_kernel; AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( - (void*)kernel, shared_mem)); - kernel<<<1, num_thread, shared_mem, stream>>>( + (void*)kernel, shared_mem_i32)); + kernel<<<1, num_thread, shared_mem_i32, stream>>>( topk_ids.data_ptr(), sorted_token_ids.data_ptr(), experts_ids.data_ptr(), diff --git a/vllm/config.py b/vllm/config.py index 4698a05020332..b0a92b2e21343 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -607,7 +607,7 @@ def _verify_cuda_graph(self) -> None: self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, self.max_model_len) - MODEL_NOT_SUPPORT_CUDA_GRAPH = ['deepseek_v3', 'mllama'] + MODEL_NOT_SUPPORT_CUDA_GRAPH = ['mllama'] if (self.hf_config.model_type in MODEL_NOT_SUPPORT_CUDA_GRAPH and not self.enforce_eager): logger.warning(