This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MKLDNN] Remove overhead of sg_mkldnn_fullyconnected op #17707

Merged 1 commit on Mar 9, 2020.
src/operator/subgraph/mkldnn/mkldnn_common.h (2 changes: 1 addition & 1 deletion)
@@ -87,7 +87,7 @@ static std::vector<float> GetWeightScales(const NDArray &weight, const NDArray *
 }
 
 static void ConvertWeightBias2MKLDNN(NDArray *weight, NDArray *bias, bool has_bias,
-                                     const mkldnn::memory::desc weight_md,
+                                     const mkldnn::memory::desc &weight_md,
                                      const mkldnn::memory::desc *bias_md,
                                      const int num_group, float data_scale,
                                      const std::vector<float> &weight_scales,
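The lone change in this file passes the weight descriptor by const reference instead of by value, so the descriptor object is no longer copied on every call. A minimal sketch of the pattern, using an illustrative `Descriptor` stand-in rather than the real mkldnn::memory::desc:

```cpp
// Illustrative stand-in for a descriptor type such as mkldnn::memory::desc.
struct Descriptor {
  unsigned char blob[256];  // plain data, but large enough that copying adds up
};

// By value: the 256-byte struct is copied at every call site.
int ConvertByValue(Descriptor weight_md) { return weight_md.blob[0]; }

// By const reference: same read-only access, no per-call copy.
int ConvertByRef(const Descriptor &weight_md) { return weight_md.blob[0]; }
```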
src/operator/subgraph/mkldnn/mkldnn_fc.cc (124 changes: 66 additions & 58 deletions)
@@ -63,10 +63,13 @@ class SgMKLDNNFCOp {
 
  private:
   bool initialized_{false};
+  bool channel_wise_runtime_{false};
+  bool reorder_data_{false};
   nnvm::Symbol subgraph_sym_;
   MKLDNNFCFullParam full_param_;
   mkldnn_args_map_t args_;
   std::shared_ptr<MKLDNNFullyConnectedForward> fwd_;
+  std::shared_ptr<mkldnn::memory> cached_data_mem_;
   std::shared_ptr<mkldnn::memory> cached_out_mem_;
   NDArray cached_weight_;
   NDArray cached_bias_;
@@ -82,28 +85,10 @@ class SgMKLDNNFCOp {
   float cached_max_output_;
   float data_scale_{0.0f};
   std::vector<float> weight_scales_;
+  size_t total_num_inputs_;
+  size_t total_num_outputs_;
 };
 
-static inline void MKLDNNFCFlattenData(const FullyConnectedParam &param,
-                                       NDArray *in_data) {
-  const mxnet::TShape ishape = in_data->shape();
-
-  // If the input data is a view of an MKLDNN array, we should create a new
-  // NDArray with reordered data.
-  if (in_data->IsMKLDNNData() && in_data->IsView())
-    *in_data = in_data->Reorder2Default();
-
-  auto data_ndim = ishape.ndim();
-  if (data_ndim != 2) {
-    if (!param.flatten) {
-      *in_data = in_data->MKLDNNDataReshape(
-          Shape2(ishape.ProdShape(0, data_ndim - 1), ishape[data_ndim - 1]));
-    } else {
-      *in_data = in_data->MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, data_ndim)));
-    }
-  }
-}
-
 void SgMKLDNNFCOp::Forward(const OpContext &ctx,
                            const std::vector<NDArray> &in_data,
                            const std::vector<OpReqType> &req,
@@ -112,9 +97,7 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
   auto &default_param = full_param_.default_param;
   bool has_bias = !default_param.no_bias;
   size_t base_num_inputs = has_bias ? 3 : 2;
-  size_t total_num_inputs = base_num_inputs;
   size_t base_num_outputs = 1;
-  size_t total_num_outputs = base_num_outputs;
 
   float min_data = 0.0f;
   float max_data = 0.0f;
@@ -123,17 +106,29 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
   float min_bias = 0.0f;
   float max_bias = 0.0f;
 
-  bool channel_wise = false;
-  if (mkldnn_param.channel_wise_quantize.has_value() &&
-      mkldnn_param.channel_wise_quantize) {
-    channel_wise = true;
-  }
+  if (!initialized_) {
+    if (mkldnn_param.channel_wise_quantize.has_value() &&
+        mkldnn_param.channel_wise_quantize) {
+      channel_wise_runtime_ = true;
+    }
+
+    total_num_inputs_ = base_num_inputs;
+    total_num_outputs_ = base_num_outputs;
+    if (mkldnn_param.quantized) {
+      total_num_inputs_ = channel_wise_runtime_ ? (base_num_inputs + 2) : (base_num_inputs * 3);
+      total_num_outputs_ =
+          mkldnn_param.enable_float_output ? base_num_outputs : (base_num_outputs * 3);
+    }
+  }
+  CHECK_EQ(in_data.size(), total_num_inputs_);
+  CHECK_EQ(out_data.size(), total_num_outputs_);
+
+  NDArray data = in_data[fullc::kData];
+  const NDArray &weight = in_data[fullc::kWeight];
+  const NDArray &output = out_data[fullc::kOut];
 
   if (mkldnn_param.quantized) {
-    if (channel_wise) {
-      total_num_inputs = base_num_inputs + 2;
-    } else {
-      total_num_inputs = base_num_inputs * 3;
+    if (!channel_wise_runtime_) {
      min_weight = in_data[base_num_inputs + quantized_fullc::kWeightMin].data().dptr<float>()[0];
      max_weight = in_data[base_num_inputs + quantized_fullc::kWeightMax].data().dptr<float>()[0];
      if (has_bias) {
@@ -143,20 +138,11 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
     }
     min_data = in_data[base_num_inputs + quantized_fullc::kDataMin].data().dptr<float>()[0];
     max_data = in_data[base_num_inputs + quantized_fullc::kDataMax].data().dptr<float>()[0];
-    if (!mkldnn_param.enable_float_output) {
-      total_num_outputs = base_num_outputs * 3;
-    }
   }
-  CHECK_EQ(in_data.size(), total_num_inputs);
-  CHECK_EQ(out_data.size(), total_num_outputs);
-
-  NDArray data = in_data[fullc::kData];
-  NDArray weight = in_data[fullc::kWeight];
-  NDArray output = out_data[fullc::kOut];
-  MKLDNNFCFlattenData(default_param, &data);
 
-  if (initialized_ && mkldnn_param.quantized) {
-    if (channel_wise) {
+  if (initialized_ && mkldnn_param.quantized &&
+      dmlc::GetEnv("MXNET_MKLDNN_QFC_DYNAMIC_PARAMS", 0)) {
[Inline review thread on the `dmlc::GetEnv` line above]

Member: Is this new?

Contributor (author): Yes. I think this env var could be removed in the future, once the time-consuming `weight_ver_ != weight.version()` operation is resolved.

+    if (channel_wise_runtime_) {
       if (cached_min_data_ != min_data || cached_max_data_ != max_data ||
           weight_ver_ != weight.version() ||
           (has_bias && (bias_ver_ != in_data[fullc::kBias].version()))) {
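For context on the guard discussed above: `dmlc::GetEnv(name, default)` reads an environment variable with a fallback, so the per-forward re-check of quantization parameters is now opt-in rather than unconditional. A minimal sketch of this gating pattern, with a plain-C++ stand-in for `dmlc::GetEnv` (illustrative, not the dmlc implementation):

```cpp
#include <cstdlib>

// Plain stand-in for dmlc::GetEnv(name, default): returns the integer value of
// an environment variable, or `def` when the variable is unset.
static int GetEnvInt(const char *name, int def) {
  const char *val = std::getenv(name);
  return val ? std::atoi(val) : def;
}

// The expensive parameter re-checks run only when the user opts in, so the
// steady-state forward path skips them entirely.
void MaybeRecheckParams(bool initialized, bool quantized) {
  if (initialized && quantized &&
      GetEnvInt("MXNET_MKLDNN_QFC_DYNAMIC_PARAMS", 0)) {
    // ... re-validate cached min/max ranges and weight/bias versions here ...
  }
}
```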
@@ -173,6 +159,7 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
 
   if (!initialized_) {
     const auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+    const auto engine = CpuEngine::Get()->get_engine();
     cached_min_data_ = min_data;
     cached_max_data_ = max_data;
     cached_min_weight_ = min_weight;
@@ -187,9 +174,22 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
     } else {
       cached_bias_ = NDArray();
     }
+    const mxnet::TShape ishape = data.shape();
+    const auto data_ndim = ishape.ndim();
+    if (data.IsMKLDNNData()) {
+      reorder_data_ = true;
+      data = data.Reorder2Default();
+    }
+    if (data_ndim != 2) {
+      if (!default_param.flatten) {
+        data = data.MKLDNNDataReshape(
+            Shape2(ishape.ProdShape(0, data_ndim - 1), ishape[data_ndim - 1]));
+      } else {
+        data = data.MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, data_ndim)));
+      }
+    }
 
     // create cached out_md
-    const mxnet::TShape ishape = data.shape();
     const mxnet::TShape oshape = output.shape();
     mkldnn::memory::dims out_dims(2);
     if (oshape.ndim() == 2) {
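These added lines fold the old MKLDNNFCFlattenData logic into the one-time `!initialized_` branch, so the reorder/reshape decision is made once instead of on every forward call. A generic sketch of that hoisting pattern, assuming (as the cached path does) that the input shape stays stable across calls; the `FCOp` class here is hypothetical, not MXNet code:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical op illustrating init-once/reuse-after: shape-dependent work is
// done on the first call and cached, so steady-state calls skip it.
class FCOp {
 public:
  void Forward(const std::vector<int64_t> &shape) {
    if (!initialized_) {
      // Expensive, shape-dependent setup happens exactly once: collapse all
      // leading dimensions into a cached 2-D view.
      flattened_rows_ = 1;
      for (size_t i = 0; i + 1 < shape.size(); ++i) flattened_rows_ *= shape[i];
      inner_dim_ = shape.back();
      initialized_ = true;
    }
    // Steady state: reuse the cached view; only a cheap sanity check remains.
    assert(shape.back() == inner_dim_);
  }

 private:
  bool initialized_{false};
  int64_t flattened_rows_{0};
  int64_t inner_dim_{0};
};
```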
@@ -206,7 +206,7 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
     }
     mkldnn::memory::desc out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(output.dtype()),
         static_cast<mkldnn::memory::format_tag>(GetDefaultFormat(2)));
-    cached_out_mem_ = std::make_shared<mkldnn::memory>(out_md, CpuEngine::Get()->get_engine());
+    cached_out_mem_ = std::make_shared<mkldnn::memory>(out_md, engine);
 
     bool support_channelwise_scale = false;
     if (mkldnn_param.quantized) {
@@ -229,15 +229,15 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
       //   True    True        True
       //   True    False       Error
       //   False   True/False  False
-      if (channel_wise && !support_channelwise_scale) {
+      if (channel_wise_runtime_ && !support_channelwise_scale) {
         LOG(FATAL)
             << "Currently, channel-wise quantization requires fuse requantize or dequantize."
            << " Please make sure the `min_calib_range` and `max_calib_range` are set when only"
            << " fuse requantize (outputs of FullyConnected are collected during calibration phase),"
            << " or the env var of `MXNET_DISABLE_MKLDNN_QFC_FLOAT_OUTPUT` and "
            << " `MXNET_DISABLE_MKLDNN_QFC_FUSE_ALL` are not set to true (default is false)";
       }
-      support_channelwise_scale = support_channelwise_scale && channel_wise;
+      support_channelwise_scale = support_channelwise_scale && channel_wise_runtime_;
 
     if (support_channelwise_scale) {
       MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, {
@@ -329,30 +329,38 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
                                has_bias ? &bias_md : nullptr,
                                1, data_scale_, weight_scales_, false);
     } else {
-      cached_weight_ = NDArray(fwd_->fwd_pd.weights_desc());
-      auto cached_weight_mem = cached_weight_.GetMKLDNNData();
-      auto def_weight_mem = weight.GetMKLDNNData();
-      std::unordered_map<int, mkldnn::memory> args(
-          {{MKLDNN_ARG_FROM, *def_weight_mem},
-           {MKLDNN_ARG_TO, *cached_weight_mem}});
-      MKLDNNStream::Get()->RegisterPrimArgs(
-          mkldnn::reorder(*def_weight_mem, *cached_weight_mem), args);
+      const auto def_weight_mem = weight.GetMKLDNNData();
+      if (def_weight_mem->get_desc() != fwd_->fwd_pd.weights_desc()) {
+        cached_weight_ = NDArray(fwd_->fwd_pd.weights_desc());
+        auto cached_weight_mem = cached_weight_.GetMKLDNNData();
+        std::unordered_map<int, mkldnn::memory> args(
+            {{MKLDNN_ARG_FROM, *def_weight_mem},
+             {MKLDNN_ARG_TO, *cached_weight_mem}});
+        MKLDNNStream::Get()->RegisterPrimArgs(
+            mkldnn::reorder(*def_weight_mem, *cached_weight_mem), args);
+      }
     }
 
-    args_[MKLDNN_ARG_SRC] = *data.GetMKLDNNData();
+    const auto data_mem = data.GetMKLDNNData();
+    cached_data_mem_ = std::make_shared<mkldnn::memory>(data_mem->get_desc(), engine);
+
+    args_[MKLDNN_ARG_SRC] = *cached_data_mem_;
     args_[MKLDNN_ARG_WEIGHTS] = *cached_weight_.GetMKLDNNData();
     if (has_bias)
       args_[MKLDNN_ARG_BIAS] = *cached_bias_.GetMKLDNNData();
+    args_[MKLDNN_ARG_DST] = *cached_out_mem_;
     initialized_ = true;
   }
 
-  auto data_mem = data.GetMKLDNNDataReorder(fwd_->fwd_pd.src_desc());
+  if (reorder_data_) {
+    data = data.Reorder2Default();
+  }
+  MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
+    cached_data_mem_->set_data_handle(reinterpret_cast<void *>(data.data().dptr<DType>()));
+  });
   MSHADOW_TYPE_SWITCH(output.dtype(), DType, {
     cached_out_mem_->set_data_handle(reinterpret_cast<void *>(output.data().dptr<DType>()));
   });
-  args_[MKLDNN_ARG_SRC] = *data_mem;
-  args_[MKLDNN_ARG_DST] = *cached_out_mem_;
   MKLDNNStream::Get()->RegisterPrimArgs(fwd_->GetFwd(), args_);
   MKLDNNStream::Get()->Submit();

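The final hunk is the heart of the overhead removal: instead of calling `GetMKLDNNDataReorder` and rebuilding the argument map on every forward, the op builds `cached_data_mem_` and `cached_out_mem_` once and thereafter only swaps the underlying data pointers. A self-contained sketch of that handle-swapping pattern against the DNNL 1.x API this PR builds on (the `CachedMemory` wrapper is illustrative, not MXNet code):

```cpp
#include <dnnl.hpp>
#include <vector>

// One-time setup: construct the dnnl::memory object once for a fixed
// descriptor; per step, point it at the caller's buffer instead of
// allocating or reordering anew.
struct CachedMemory {
  dnnl::memory mem;  // created once, reused every step

  CachedMemory(const dnnl::memory::desc &md, const dnnl::engine &eng)
      : mem(md, eng) {}

  // Swapping the handle is cheap; constructing a dnnl::memory is not.
  void bind(void *buffer) { mem.set_data_handle(buffer); }
};

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::memory::desc md({8, 16}, dnnl::memory::data_type::f32,
                        dnnl::memory::format_tag::nc);
  CachedMemory cached(md, eng);

  std::vector<float> batch_a(8 * 16), batch_b(8 * 16);
  cached.bind(batch_a.data());  // step 1: no allocation, no reorder
  cached.bind(batch_b.data());  // step 2: same memory object, new data
  return 0;
}
```

This only pays off because the descriptor (shape, dtype, layout) is identical across calls; anything shape-dependent stays behind the `!initialized_` guard above.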