
Commit

fix format error
jinzhen-lin committed Jan 20, 2025
1 parent a090af9 commit 4fe426e
Showing 1 changed file with 18 additions and 14 deletions.
csrc/moe/moe_align_sum_kernels.cu (32 changes: 18 additions & 14 deletions)
@@ -8,7 +8,7 @@
 #include "../cuda_compat.h"
 #include "../dispatch_utils.h"
 
-#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
+#define CEILDIV(x, y) (((x) + (y)-1) / (y))
 
 namespace vllm {
 namespace moe {
@@ -33,7 +33,7 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
 
   extern __shared__ int32_t shared_mem[];
   int32_t* cumsum = shared_mem;  // 1d tensor with shape (num_experts + 1)
-  token_cnts_t* tokens_cnts = (token_cnts_t *) (shared_mem + blockDim.x + 1);
+  token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + blockDim.x + 1);
 
   for (int i = 0; i < num_experts; ++i) {
     tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
@@ -70,7 +70,7 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
                           block_size) *
                       block_size;
     }
-    *total_tokens_post_pad = (int32_t) cumsum[num_experts];
+    *total_tokens_post_pad = (int32_t)cumsum[num_experts];
   }
 
   __syncthreads();
@@ -222,20 +222,21 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
 
   int device_max_shared_mem;
   auto dev = topk_ids.get_device();
-  cudaDeviceGetAttribute(&device_max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
+  cudaDeviceGetAttribute(&device_max_shared_mem,
+                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
 
   const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
   const int32_t shared_mem_i32 =
-      ((num_thread + 1) * num_experts + (num_experts + 1)) *
-      sizeof(int32_t);
+      ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
   const int32_t shared_mem_i16 =
-      ((num_thread + 1) * num_experts) * sizeof(uint16_t) + (num_experts + 1) *
-      sizeof(int32_t);
+      ((num_thread + 1) * num_experts) * sizeof(uint16_t) +
+      (num_experts + 1) * sizeof(int32_t);
 
   bool use_global_memory = false, use_i16 = false;
   if (shared_mem_i16 > device_max_shared_mem) {
     use_global_memory = true;
-  } else if (shared_mem_i32 > device_max_shared_mem && topk_ids.numel() <= 65535) {
+  } else if (shared_mem_i32 > device_max_shared_mem &&
+             topk_ids.numel() <= 65535) {
     // when nelements of topk_ids is smaller than 65535 (max value of uint16),
     // element value of token_cnts would also smaller than 65535,
     // so we can use uint16 as dtype of token_cnts
@@ -249,8 +250,9 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
     // tensors
     const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
 
-    auto options_int =
-        torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
+    auto options_int = torch::TensorOptions()
+                           .dtype(torch::kInt)
+                           .device(topk_ids.device());
     torch::Tensor token_cnts_buffer =
         torch::empty({(num_experts + 1) * num_experts}, options_int);
     torch::Tensor cumsum_buffer =
@@ -270,7 +272,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
     VLLM_DISPATCH_INTEGRAL_TYPES(
         topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
           // set dynamic shared mem
-          auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t, uint16_t>;
+          auto kernel =
+              vllm::moe::moe_align_block_size_kernel<scalar_t, uint16_t>;
           AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
               (void*)kernel, shared_mem_i16));
           kernel<<<1, num_thread, shared_mem_i16, stream>>>(
@@ -282,8 +285,9 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
         });
   } else {
     VLLM_DISPATCH_INTEGRAL_TYPES(
-      topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
-      auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
+        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
+          auto kernel =
+              vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
           AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
               (void*)kernel, shared_mem_i32));
           kernel<<<1, num_thread, shared_mem_i32, stream>>>(
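The host-side code touched by the last four hunks sizes two shared-memory layouts, one with int32 token counters and one with uint16 counters, and falls back to global-memory buffers when neither fits, as the "use uint16 as dtype of token_cnts" comment in the @@ -222 hunk explains. The following standalone sketch re-traces that selection logic in plain C++ so it can be compiled and run without CUDA or PyTorch; the device limit, expert count, and top-k element count are invented example values, and plain constants plus std::printf stand in for the actual device attribute query and kernel launch.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // Invented example inputs; the real values come from the device and model.
  const int32_t device_max_shared_mem = 99 * 1024;  // pretend opt-in limit
  const int32_t num_experts = 160;
  const int64_t topk_numel = 32768;  // number of elements in topk_ids
  const int32_t WARP_SIZE = 32;

  // Same sizing formulas as the diff: a (num_thread + 1) x num_experts
  // counter table plus a (num_experts + 1) cumsum array.
  const int32_t num_thread = std::max(num_experts, WARP_SIZE);
  const int32_t shared_mem_i32 =
      ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
  const int32_t shared_mem_i16 =
      ((num_thread + 1) * num_experts) * sizeof(uint16_t) +
      (num_experts + 1) * sizeof(int32_t);

  bool use_global_memory = false, use_i16 = false;
  if (shared_mem_i16 > device_max_shared_mem) {
    // Even 16-bit counters overflow the shared-memory budget:
    // fall back to global-memory buffers.
    use_global_memory = true;
  } else if (shared_mem_i32 > device_max_shared_mem && topk_numel <= 65535) {
    // 32-bit counters do not fit, but each per-expert count is bounded by
    // numel(topk_ids) <= 65535, so uint16 counters are safe at half the size.
    use_i16 = true;
  }

  std::printf("i32 bytes = %d, i16 bytes = %d -> %s\n", shared_mem_i32,
              shared_mem_i16,
              use_global_memory ? "global memory"
                                : (use_i16 ? "uint16 shared" : "int32 shared"));
  return 0;
}

With the example numbers above, the int32 layout slightly exceeds the pretend limit while the uint16 layout fits, so the sketch selects the uint16 path, mirroring the branch the diff reformats.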
