Commit 88d8b8b

fix bug and skip fc_add pass in bf16

zhanglirong1999 committed Oct 24, 2023
1 parent ded5256 commit 88d8b8b
Showing 5 changed files with 22 additions and 9 deletions.
12 changes: 12 additions & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -405,6 +405,7 @@ void CpuPassStrategy::EnableMkldnnQuantizer() {

 void CpuPassStrategy::EnableMkldnnBfloat16() {
 #ifdef PADDLE_WITH_DNNL
+  EraseFcAddMkldnnPasses();
   if (!use_mkldnn_bfloat16_) {
     passes_.emplace_back("fc_mkldnn_pass");
     passes_.emplace_back("fc_act_mkldnn_fuse_pass");

@@ -508,6 +509,17 @@ void CpuPassStrategy::EraseFcMkldnnPasses() {
   }
 }

+void CpuPassStrategy::EraseFcAddMkldnnPasses() {
+  std::vector<std::string> fc_add_passes_to_erase(
+      {"fc_elementwise_add_mkldnn_fuse_pass"});
+  for (const auto &pass : fc_add_passes_to_erase) {
+    int idx = static_cast<int>(GetPassIndex(pass));
+    if (idx != -1) {
+      passes_.erase(std::begin(passes_) + idx);
+    }
+  }
+}
+
 XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
   passes_.assign({
       "delete_assign_op_pass",
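Note: GetPassIndex() is declared elsewhere and is not part of this diff. As a minimal standalone sketch of the erase-by-name pattern that the new EraseFcAddMkldnnPasses() relies on (substituting std::find for GetPassIndex, so the helper below is illustrative, not Paddle's implementation):

    #include <algorithm>
    #include <string>
    #include <vector>

    // Remove the first occurrence of a named pass from a pass list; mirrors
    // the effect of the GetPassIndex() + erase() pair in the function above.
    void ErasePass(std::vector<std::string> &passes, const std::string &name) {
      auto it = std::find(passes.begin(), passes.end(), name);
      if (it != passes.end()) {
        passes.erase(it);
      }
    }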
3 changes: 3 additions & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.h
@@ -228,6 +228,9 @@ class PD_INFER_DECL CpuPassStrategy : public PassStrategy {
   /// \brief Erase MKLDNN fc passes.
   void EraseFcMkldnnPasses();

+  /// \brief Erase MKLDNN fc_elementwise_add passes.
+  void EraseFcAddMkldnnPasses();
+
   /// \cond Protected
   bool use_mkldnn_quantizer_{false};
   bool use_mkldnn_bfloat16_{false};
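For context, a minimal usage sketch (not part of this diff) of the user-facing path that reaches EnableMkldnnBfloat16(), assuming the standard paddle_infer C++ API; model_dir is a placeholder:

    #include <string>
    #include "paddle_inference_api.h"

    paddle_infer::Config MakeBf16Config(const std::string &model_dir) {
      paddle_infer::Config config(model_dir);
      config.EnableMKLDNN();
      // With this commit, enabling bf16 also erases the fc+elementwise_add
      // fuse pass via EraseFcAddMkldnnPasses().
      config.EnableMkldnnBfloat16();
      return config;
    }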
14 changes: 5 additions & 9 deletions paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
@@ -302,14 +302,10 @@ class FCMKLDNNHandler
         this->dev_ctx_.GetBlob(residual_key));
     if (!memory_p) {
       auto dims = this->fwd_pd_->dst_desc().get_dims();
-      if (phi::funcs::is_int8<T_in>() || phi::funcs::is_bfloat16<T_in>()) {
-        constexpr bool is_int8 = phi::funcs::is_int8<T_in>();
-        auto data_type = dnnl::memory::data_type::bf16;
-        if (is_int8) {
-          data_type = residual->dtype() == phi::DataType::INT8
-                          ? dnnl::memory::data_type::s8
-                          : dnnl::memory::data_type::u8;
-        }
+      if (phi::funcs::is_int8<T_in>()) {
+        auto data_type = residual->dtype() == phi::DataType::INT8
+                             ? dnnl::memory::data_type::s8
+                             : dnnl::memory::data_type::u8;

         auto src_0_md =
             dnnl::memory::desc(dims, data_type, dnnl::memory::format_tag::ab);
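With the fc_add fusion now skipped for bf16 by the pass change above, only int8 kernels still need a residual memory descriptor with an explicit data type. A hedged oneDNN sketch of the selection that remains (an illustrative helper, not the handler's actual interface):

    #include "dnnl.hpp"

    // Signed int8 residuals map to s8, unsigned to u8, matching the ternary
    // retained in FCMKLDNNHandler above.
    dnnl::memory::desc MakeResidualDesc(const dnnl::memory::dims &dims,
                                        bool residual_is_signed) {
      auto data_type = residual_is_signed ? dnnl::memory::data_type::s8
                                          : dnnl::memory::data_type::u8;
      return dnnl::memory::desc(dims, data_type,
                                dnnl::memory::format_tag::ab);
    }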
@@ -603,7 +599,7 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
       fc_args.insert({DNNL_ARG_BIAS, *bias_memory_p});
     }

-    if (phi::funcs::is_int8<T_in>() || phi::funcs::is_bfloat16<T_in>()) {
+    if (phi::funcs::is_int8<T_in>()) {
       handler.SetScalesIfNeeded(&fc_args);
     }

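Correspondingly, runtime scale arguments are now attached only on the int8 path. A hedged sketch of the argument-map pattern, assuming oneDNN's C++ API (the helper and its parameters are illustrative, not the kernel's actual code):

    #include <unordered_map>
    #include "dnnl.hpp"

    void ExecuteFc(dnnl::inner_product_forward &fc_prim, dnnl::stream &astream,
                   std::unordered_map<int, dnnl::memory> &fc_args,
                   bool is_int8, const dnnl::memory &src_scales) {
      if (is_int8) {
        // Attach runtime source scales only for quantized execution.
        fc_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales});
      }
      fc_prim.execute(astream, fc_args);
    }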
1 change: 1 addition & 0 deletions test/mkldnn/test_fc_add_int8_mkldnn_op.py
@@ -14,6 +14,7 @@

 import sys
 import unittest
+
 sys.path.append("../legacy_test")
 import numpy as np
 from op_test import OpTest, OpTestTool
1 change: 1 addition & 0 deletions test/mkldnn/test_fc_add_mkldnn_op.py
@@ -14,6 +14,7 @@

 import sys
 import unittest
+
 sys.path.append("../legacy_test")

 import numpy as np
