[Inference] Fused Moe Optimization #70059

Merged
13 commits merged on Dec 28, 2024
@@ -109,14 +109,16 @@ static std::vector<CutlassTileConfig> get_candidate_tiles(
const bool is_weight_only_encoder,
const bool simt_configs_only,
const int sm,
-const int group_size) {
+const int group_size,
+const bool is_moe) {
VLOG(3) << "get_candidate_tiles sm: " << sm;
std::vector<CutlassTileConfig> simt_configs{
CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};

std::vector<CutlassTileConfig> square_configs{
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64,
};
std::vector<CutlassTileConfig> quant_B_configs_sm70{
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
@@ -129,6 +131,13 @@ static std::vector<CutlassTileConfig> get_candidate_tiles(
CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64,
CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
};
+if (is_moe) {
+quant_B_configs_sm80.push_back(
+CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64);
+} else {
+quant_B_configs_sm80.push_back(
+CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64);
+}
std::vector<CutlassTileConfig> quant_B_configs_sm80_finegrained{
CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64,
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
@@ -164,13 +173,15 @@ static std::vector<CutlassGemmConfig> get_candidate_configs(
const int group_size,
const bool is_weight_only,
const bool is_weight_only_encoder,
-const bool simt_configs_only) {
+const bool simt_configs_only,
+const bool is_moe) {
std::vector<CutlassTileConfig> tiles =
get_candidate_tiles(is_weight_only,
is_weight_only_encoder,
simt_configs_only,
sm,
-group_size);
+group_size,
+is_moe);

std::vector<CutlassGemmConfig> candidate_configs;
const int min_stages = 2;
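The hunks above thread a new is_moe flag from get_candidate_configs into get_candidate_tiles, so MoE layers get a different extra SM80 tile than dense layers. Below is a compilable toy sketch of that branch; the enum and helper are illustrative stand-ins, not Paddle's actual types. A plausible rationale (not stated in the diff) is that MoE splits the token batch across experts, so each per-expert GEMM sees a smaller effective m and favors a smaller CTA tile.

#include <iostream>
#include <vector>

// Illustrative stand-in for CutlassTileConfig; only the two values this
// hunk distinguishes are reproduced.
enum class TileConfig {
  CtaShape64x128x64_WarpShape64x32x64,    // appended when is_moe is true
  CtaShape128x128x64_WarpShape128x32x64,  // appended otherwise
};

// Mirrors the branch added to get_candidate_tiles() for SM80 weight-only GEMMs.
void append_sm80_extra_tile(std::vector<TileConfig>* configs, bool is_moe) {
  if (is_moe) {
    // Smaller CTA tile: MoE runs many small per-expert GEMMs.
    configs->push_back(TileConfig::CtaShape64x128x64_WarpShape64x32x64);
  } else {
    // Larger warp-M tile for a single dense GEMM.
    configs->push_back(TileConfig::CtaShape128x128x64_WarpShape128x32x64);
  }
}

int main() {
  std::vector<TileConfig> configs;
  append_sm80_extra_tile(&configs, /*is_moe=*/true);
  std::cout << "candidate tiles: " << configs.size() << std::endl;
  return 0;
}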
@@ -29,9 +29,11 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h"
+#include <optional>
#include "paddle/common/errors.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/autogen/arch_define.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/gemm_config_manager.h"
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic pop
@@ -285,6 +287,29 @@ void dispatch_gemm_to_cutlass(const T* A,
stream,
occupancy);
break;
+case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
+dispatch_gemm_config<T,
+WeightType,
+arch,
+EpilogueTag,
+FineGrained,
+cutlass::gemm::GemmShape<128, 128, 64>,
+cutlass::gemm::GemmShape<128, 32, 64>>(
+A,
+B,
+weight_scales,
+biases,
+C,
+m,
+n,
+k,
+group_size,
+gemm_config,
+workspace,
+workspace_bytes,
+stream,
+occupancy);
+break;
// config for M_16000_N_12288_K_6144 in encoder
case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64:
dispatch_gemm_config<T,
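The new switch case follows the file's existing dispatch pattern: a runtime CutlassTileConfig value is translated into compile-time cutlass::gemm::GemmShape template arguments, one case per supported tile. A compilable toy version of that pattern follows; the names here (GemmShape, launch_gemm, dispatch) are illustrative stand-ins, not the real kernel launcher.

#include <cstdio>

// Toy stand-in for cutlass::gemm::GemmShape<M, N, K>.
template <int M, int N, int K>
struct GemmShape {
  static constexpr int kM = M, kN = N, kK = K;
};

enum class Tile { Cta128x128x64_Warp128x32x64 };

// Stand-in for dispatch_gemm_config: shapes arrive as template parameters,
// so each case instantiates a distinct kernel at compile time.
template <typename CtaShape, typename WarpShape>
void launch_gemm() {
  std::printf("cta %dx%dx%d, warp %dx%dx%d\n",
              CtaShape::kM, CtaShape::kN, CtaShape::kK,
              WarpShape::kM, WarpShape::kN, WarpShape::kK);
}

// One switch case per supported tile, as in dispatch_gemm_to_cutlass above.
void dispatch(Tile tile) {
  switch (tile) {
    case Tile::Cta128x128x64_Warp128x32x64:
      launch_gemm<GemmShape<128, 128, 64>, GemmShape<128, 32, 64>>();
      break;
  }
}

int main() { dispatch(Tile::Cta128x128x64_Warp128x32x64); return 0; }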
@@ -573,41 +598,92 @@ void CutlassFpAIntBGemmRunner<T, WeightType>::run_gemm<EpilogueTag,
static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
const bool is_weight_only_encoder = m >= 512 ? true : false;
std::vector<CutlassGemmConfig> candidate_configs = get_candidate_configs(
-sm_, group_size, is_weight_only, is_weight_only_encoder, false);
-std::vector<int> occupancies(candidate_configs.size());
+sm_, group_size, is_weight_only, is_weight_only_encoder, false, false);

-for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
-dispatch_to_arch<EpilogueTag, FineGrained>(A,
-B,
-weight_scales,
-biases,
-C,
-m,
-n,
-k,
-group_size,
-candidate_configs[ii],
-workspace_ptr,
-workspace_bytes,
-stream,
-&occupancies[ii]);
-}
// Standard GEMM, so 1 "expert". We use the same function for MoE and regular
// FFN.
static constexpr int num_experts = 1;
-CutlassGemmConfig chosen_config =
-estimate_best_config_from_occupancies(candidate_configs,
-occupancies,
-m,
-n,
-k,
-group_size,
-num_experts,
-split_k_limit,
-workspace_bytes,
-multi_processor_count_,
-is_weight_only,
-sm_);
+static constexpr int warm_time = 5;
+static constexpr int test_time = 10;
+
+auto& gemmConfigManager = phi::GemmConfigManager::Instance();
+constexpr GemmDataType dtype = getGemmDataType<T>();
+constexpr GemmDataType wdtype = getGemmDataType<WeightType>();
+GemmIDType gemmId{n, k, GemmType::FPAINTBGEMM, dtype, wdtype, num_experts};
+CutlassGemmConfig chosen_config;
+auto chosen_config_optional = gemmConfigManager.getBestConfig(gemmId, m);
+if (chosen_config_optional != std::nullopt) {
+chosen_config = chosen_config_optional.value();
+} else {
+float best_time = std::numeric_limits<float>::max();
+CutlassGemmConfig best_config;
+int profile_m = gemmConfigManager.nextPowerOfTwo(m);
+bool found_one = false;
+
+for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
+for (int i = 0; i < warm_time; i++) {
+dispatch_to_arch<EpilogueTag, FineGrained>(A,
+B,
+weight_scales,
+biases,
+C,
+m,
+n,
+k,
+group_size,
+candidate_configs[ii],
+workspace_ptr,
+workspace_bytes,
+stream);
+}
+cudaEvent_t start;
+cudaEvent_t stop;
+check_cuda_error(cudaEventCreate(&start));
+check_cuda_error(cudaEventCreate(&stop));
+check_cuda_error(cudaStreamSynchronize(stream));
+check_cuda_error(cudaEventRecord(start, stream));
+for (int i = 0; i < test_time; i++) {
+dispatch_to_arch<EpilogueTag, FineGrained>(A,
+B,
+weight_scales,
+biases,
+C,
+m,
+n,
+k,
+group_size,
+candidate_configs[ii],
+workspace_ptr,
+workspace_bytes,
+stream);
+}
+check_cuda_error(cudaEventRecord(stop, stream));
+check_cuda_error(cudaEventSynchronize(stop));
+found_one = true;
+float elapsed;
+check_cuda_error(cudaEventElapsedTime(&elapsed, start, stop));
+check_cuda_error(cudaEventDestroy(start));
+check_cuda_error(cudaEventDestroy(stop));
+if (elapsed < best_time) {
+best_time = elapsed;
+best_config = candidate_configs[ii];
+}
+VLOG(4) << "profile_m " << profile_m;
+VLOG(4) << "candidate_config tile_config "
+<< static_cast<int>(candidate_configs[ii].tile_config);
+VLOG(4) << "candidate_config split_k_style "
+<< static_cast<int>(candidate_configs[ii].split_k_style);
+VLOG(4) << "candidate_config split_k_factor "
+<< candidate_configs[ii].split_k_factor;
+VLOG(4) << "candidate_config stages " << candidate_configs[ii].stages;
+VLOG(4) << "elapsed time: " << elapsed;
+VLOG(4) << "best_time: " << best_time;
+}
+if (found_one) {
+gemmConfigManager.addBestConfig(gemmId, profile_m, best_config);
+chosen_config = best_config;
+}
+}

dispatch_to_arch<EpilogueTag, FineGrained>(A,
B,
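The rewritten run_gemm drops the old occupancy-based heuristic in favor of on-device profiling: each candidate config is warmed up, timed with CUDA events over repeated launches, and the winner is cached in GemmConfigManager keyed by the GEMM id and a power-of-two m bucket, so later calls at similar shapes skip profiling entirely. Below is a condensed, compilable sketch of that profile-then-cache pattern; the cache (a plain std::map), next_pow2, run_candidate, and the dummy kernel are hypothetical stand-ins, not Paddle's GemmConfigManager API.

#include <cuda_runtime.h>
#include <cstdio>
#include <limits>
#include <map>

__global__ void dummy_gemm_kernel() {}  // stand-in for one candidate kernel

// Stand-in for launching candidate config `c` (illustrative only).
void run_candidate(int /*c*/, cudaStream_t stream) {
  dummy_gemm_kernel<<<1, 1, 0, stream>>>();
}

// Round m up to a power of two, in the spirit of nextPowerOfTwo above,
// so nearby m values share one cached profiling result.
int next_pow2(int m) { int p = 1; while (p < m) p <<= 1; return p; }

int profile_best_config(int m, int num_configs, cudaStream_t stream,
                        std::map<int, int>* cache) {
  const int bucket = next_pow2(m);
  auto hit = cache->find(bucket);
  if (hit != cache->end()) return hit->second;  // cached: skip profiling

  float best_time = std::numeric_limits<float>::max();
  int best_config = 0;
  for (int c = 0; c < num_configs; ++c) {
    for (int i = 0; i < 5; ++i) run_candidate(c, stream);  // warm_time = 5
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    cudaStreamSynchronize(stream);
    cudaEventRecord(start, stream);
    for (int i = 0; i < 10; ++i) run_candidate(c, stream);  // test_time = 10
    cudaEventRecord(stop, stream);
    cudaEventSynchronize(stop);
    float elapsed_ms = 0.0f;
    cudaEventElapsedTime(&elapsed_ms, start, stop);
    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    if (elapsed_ms < best_time) {
      best_time = elapsed_ms;
      best_config = c;
    }
  }
  (*cache)[bucket] = best_config;  // remember the winner for this m bucket
  return best_config;
}

int main() {
  std::map<int, int> cache;
  int best = profile_best_config(/*m=*/48, /*num_configs=*/3, /*stream=*/0, &cache);
  std::printf("best config: %d\n", best);
  return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}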
@@ -90,6 +90,7 @@
"cutlass::gemm::GemmShape<32, 128, 64>",
"cutlass::gemm::GemmShape<64, 128, 64>",
"cutlass::gemm::GemmShape<128, 128, 64>",
"cutlass::gemm::GemmShape<128, 128, 64>",
"cutlass::gemm::GemmShape<128, 256, 64>",
"cutlass::gemm::GemmShape<256, 128, 64>",
]
@@ -98,6 +99,7 @@
"cutlass::gemm::GemmShape<32, 32, 64>",
"cutlass::gemm::GemmShape<64, 64, 64>",
"cutlass::gemm::GemmShape<64, 64, 64>",
"cutlass::gemm::GemmShape<128, 32, 64>",
"cutlass::gemm::GemmShape<64, 64, 64>",
"cutlass::gemm::GemmShape<64, 64, 64>",
]
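These two lists in the kernel autogeneration script appear to be paired index by index, so the new CTA entry GemmShape<128, 128, 64> lines up with the new warp entry GemmShape<128, 32, 64>, generating the launcher instantiation for the CtaShape128x128x64_WarpShape128x32x64 dispatch case added above.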