Skip to content

Commit

Permalink
[Inference] Fused Moe Optimization (#70059)
Browse files Browse the repository at this point in the history
* add gemm_config_manager
* add serialize & deserialize to support loading profiles from JSON
  • Loading branch information
CJ77Qi authored Dec 28, 2024
1 parent e209215 commit fc4ff25
Show file tree
Hide file tree
Showing 5 changed files with 494 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,16 @@ static std::vector<CutlassTileConfig> get_candidate_tiles(
const bool is_weight_only_encoder,
const bool simt_configs_only,
const int sm,
const int group_size) {
const int group_size,
const bool is_moe) {
VLOG(3) << "get_candidate_tiles sm: " << sm;
std::vector<CutlassTileConfig> simt_configs{
CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};

std::vector<CutlassTileConfig> square_configs{
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64,
};
std::vector<CutlassTileConfig> quant_B_configs_sm70{
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
Expand All @@ -129,6 +131,13 @@ static std::vector<CutlassTileConfig> get_candidate_tiles(
CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64,
CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
};
if (is_moe) {
quant_B_configs_sm80.push_back(
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64);
} else {
quant_B_configs_sm80.push_back(
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64);
}
std::vector<CutlassTileConfig> quant_B_configs_sm80_finegrained{
CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64,
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
Expand Down Expand Up @@ -164,13 +173,15 @@ static std::vector<CutlassGemmConfig> get_candidate_configs(
const int group_size,
const bool is_weight_only,
const bool is_weight_only_encoder,
const bool simt_configs_only) {
const bool simt_configs_only,
const bool is_moe) {
std::vector<CutlassTileConfig> tiles =
get_candidate_tiles(is_weight_only,
is_weight_only_encoder,
simt_configs_only,
sm,
group_size);
group_size,
is_moe);

std::vector<CutlassGemmConfig> candidate_configs;
const int min_stages = 2;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h"
#include <optional>
#include "paddle/common/errors.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/autogen/arch_define.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/gemm_config_manager.h"
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic pop
Expand Down Expand Up @@ -285,6 +287,29 @@ void dispatch_gemm_to_cutlass(const T* A,
stream,
occupancy);
break;
case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
dispatch_gemm_config<T,
WeightType,
arch,
EpilogueTag,
FineGrained,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<128, 32, 64>>(
A,
B,
weight_scales,
biases,
C,
m,
n,
k,
group_size,
gemm_config,
workspace,
workspace_bytes,
stream,
occupancy);
break;
// config for M_16000_N_12288_K_6144 in encoder
case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64:
dispatch_gemm_config<T,
Expand Down Expand Up @@ -573,41 +598,92 @@ void CutlassFpAIntBGemmRunner<T, WeightType>::run_gemm<EpilogueTag,
static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
const bool is_weight_only_encoder = m >= 512 ? true : false;
std::vector<CutlassGemmConfig> candidate_configs = get_candidate_configs(
sm_, group_size, is_weight_only, is_weight_only_encoder, false);
std::vector<int> occupancies(candidate_configs.size());
sm_, group_size, is_weight_only, is_weight_only_encoder, false, false);

for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
dispatch_to_arch<EpilogueTag, FineGrained>(A,
B,
weight_scales,
biases,
C,
m,
n,
k,
group_size,
candidate_configs[ii],
workspace_ptr,
workspace_bytes,
stream,
&occupancies[ii]);
}
// Standard GEMM, so 1 "expert". We use the same function for MoE and regular
// FFN.
static constexpr int num_experts = 1;
CutlassGemmConfig chosen_config =
estimate_best_config_from_occupancies(candidate_configs,
occupancies,
m,
n,
k,
group_size,
num_experts,
split_k_limit,
workspace_bytes,
multi_processor_count_,
is_weight_only,
sm_);
static constexpr int warm_time = 5;
static constexpr int test_time = 10;

auto& gemmConfigManager = phi::GemmConfigManager::Instance();
constexpr GemmDataType dtype = getGemmDataType<T>();
constexpr GemmDataType wdtype = getGemmDataType<WeightType>();
GemmIDType gemmId{n, k, GemmType::FPAINTBGEMM, dtype, wdtype, num_experts};
CutlassGemmConfig chosen_config;
auto chosen_config_optional = gemmConfigManager.getBestConfig(gemmId, m);
if (chosen_config_optional != std::nullopt) {
chosen_config = chosen_config_optional.value();
} else {
float best_time = std::numeric_limits<float>::max();
CutlassGemmConfig best_config;
int profile_m = gemmConfigManager.nextPowerOfTwo(m);
bool found_one = false;

for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
for (int i = 0; i < warm_time; i++) {
dispatch_to_arch<EpilogueTag, FineGrained>(A,
B,
weight_scales,
biases,
C,
m,
n,
k,
group_size,
candidate_configs[ii],
workspace_ptr,
workspace_bytes,
stream);
}
cudaEvent_t start;
cudaEvent_t stop;
check_cuda_error(cudaEventCreate(&start));
check_cuda_error(cudaEventCreate(&stop));
check_cuda_error(cudaStreamSynchronize(stream));
check_cuda_error(cudaEventRecord(start, stream));
for (int i = 0; i < test_time; i++) {
dispatch_to_arch<EpilogueTag, FineGrained>(A,
B,
weight_scales,
biases,
C,
m,
n,
k,
group_size,
candidate_configs[ii],
workspace_ptr,
workspace_bytes,
stream);
}
check_cuda_error(cudaEventRecord(stop, stream));
check_cuda_error(cudaEventSynchronize(stop));
found_one = true;
float elapsed;
check_cuda_error(cudaEventElapsedTime(&elapsed, start, stop));
check_cuda_error(cudaEventDestroy(start));
check_cuda_error(cudaEventDestroy(stop));
if (elapsed < best_time) {
best_time = elapsed;
best_config = candidate_configs[ii];
}
VLOG(4) << "profile_m" << profile_m;
VLOG(4) << "candidate_config tile_config"
<< static_cast<int>(candidate_configs[ii].tile_config);
VLOG(4) << "candidate_config split_k_style"
<< static_cast<int>(candidate_configs[ii].split_k_style);
VLOG(4) << "candidate_config split_k_factor "
<< candidate_configs[ii].split_k_factor;
VLOG(4) << "candidate_config stages " << candidate_configs[ii].stages;
VLOG(4) << "elapsed time: " << elapsed;
VLOG(4) << "best_time: " << best_time;
}
if (found_one) {
gemmConfigManager.addBestConfig(gemmId, profile_m, best_config);
chosen_config = best_config;
}
}

dispatch_to_arch<EpilogueTag, FineGrained>(A,
B,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
"cutlass::gemm::GemmShape<32, 128, 64>",
"cutlass::gemm::GemmShape<64, 128, 64>",
"cutlass::gemm::GemmShape<128, 128, 64>",
"cutlass::gemm::GemmShape<128, 128, 64>",
"cutlass::gemm::GemmShape<128, 256, 64>",
"cutlass::gemm::GemmShape<256, 128, 64>",
]
Expand All @@ -98,6 +99,7 @@
"cutlass::gemm::GemmShape<32, 32, 64>",
"cutlass::gemm::GemmShape<64, 64, 64>",
"cutlass::gemm::GemmShape<64, 64, 64>",
"cutlass::gemm::GemmShape<128, 32, 64>",
"cutlass::gemm::GemmShape<64, 64, 64>",
"cutlass::gemm::GemmShape<64, 64, 64>",
]
Expand Down
Loading

0 comments on commit fc4ff25

Please sign in to comment.