This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[mkldnn-v1.0] Add MKL-DNN int8 fc #16457

Merged 3 commits on Oct 13, 2019
Changes from all commits
10 changes: 5 additions & 5 deletions src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -216,15 +216,15 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &full_param,
   auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut],
                                  fwd->fwd_pd.dst_desc(), req[fullc::kOut], &data);

-  std::unordered_map<int, mkldnn::memory> args = {
+  mkldnn_args_map_t args = {
     {MKLDNN_ARG_SRC, *data_mem},
     {MKLDNN_ARG_WEIGHTS, *weight_mem},
     {MKLDNN_ARG_DST, *out_mem.second},
   };
   if (!full_param.default_param.no_bias) {
     auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(
         fwd->fwd_pd.bias_desc());
-    args.insert({ MKLDNN_ARG_BIAS, *bias_mem});
+    args[MKLDNN_ARG_BIAS] = *bias_mem;
   }
   MKLDNNStream::Get()->RegisterPrimArgs(fwd->GetFwd(), args);
   CommitOutput(out_data[fullc::kOut], out_mem);
@@ -298,7 +298,7 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
                                      ipBwdData_pd.diff_src_desc(),
                                      req[fullc::kData]);
-  std::unordered_map<int, mkldnn::memory> args = {
+  mkldnn_args_map_t args = {
     {MKLDNN_ARG_DIFF_DST, *out_grad_mem},
     {MKLDNN_ARG_WEIGHTS, *weight_mem},
     {MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}
@@ -317,7 +317,7 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[fullc::kWeight],
                                                ipBwdWeights_pd.diff_weights_desc(),
                                                req[fullc::kWeight]);
-  std::unordered_map<int, mkldnn::memory> args = {
+  mkldnn_args_map_t args = {
     {MKLDNN_ARG_DIFF_DST, *out_grad_mem},
     {MKLDNN_ARG_SRC, *data_mem},
     {MKLDNN_ARG_DIFF_WEIGHTS, *in_grad_weight.second},
@@ -328,7 +328,7 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
     in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias],
                                    ipBwdWeights_pd.diff_bias_desc(),
                                    req[fullc::kBias]);
-    args.insert({MKLDNN_ARG_DIFF_BIAS, *in_grad_bias.second});
+    args[MKLDNN_ARG_DIFF_BIAS] = *in_grad_bias.second;
   }
   MKLDNNStream::Get()->RegisterPrimArgs(
       mkldnn::inner_product_backward_weights(ipBwdWeights_pd), args);
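
The pattern adopted throughout this file is MKL-DNN v1.0's explicit execution arguments: a primitive no longer owns its input/output memories, so each execution is given a map from MKLDNN_ARG_* tags to memory objects (judging by the line it replaces, mkldnn_args_map_t is MXNet's alias for std::unordered_map<int, mkldnn::memory>). A minimal standalone sketch of the same pattern against the plain v1.x C++ API; run_fc and fc_pd are illustrative names, not part of this PR:

    #include <unordered_map>
    #include <mkldnn.hpp>

    // Execute a fully connected (inner product) primitive, v1.x style.
    void run_fc(const mkldnn::inner_product_forward::primitive_desc &fc_pd,
                const mkldnn::memory &src, const mkldnn::memory &weights,
                const mkldnn::memory &bias, const mkldnn::memory &dst,
                mkldnn::stream &strm) {
      // All tensors are supplied at execute() time via MKLDNN_ARG_* tags.
      std::unordered_map<int, mkldnn::memory> args = {
          {MKLDNN_ARG_SRC, src},
          {MKLDNN_ARG_WEIGHTS, weights},
          {MKLDNN_ARG_BIAS, bias},  // only for a with-bias primitive
          {MKLDNN_ARG_DST, dst},
      };
      mkldnn::inner_product_forward(fc_pd).execute(strm, args);
      strm.wait();
    }

This also motivates the args.insert → args[] change: the bias entry is added conditionally after the initializer list, and operator[] expresses "set this tag" directly; the two forms are equivalent for a key not yet in the map.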
src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
@@ -24,7 +24,7 @@
  * \author Ciyong Chen
  */

-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "../../nn/mkldnn/mkldnn_fully_connected-inl.h"
 #include "../quantization_utils.h"

@@ -89,33 +89,40 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs &attrs,
   auto &fwd = GetFCFwd(param, is_train, data, weight,
                        param.no_bias ? nullptr : &quantized_bias, out_md);

-  auto data_mem = in_data[fullc::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_primitive_desc());
+  auto data_mem = in_data[fullc::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_desc());
   const mkldnn::memory *weight_mem = nullptr;

   if (weight.IsDefaultData()) {
     // We also need to modify the layout of the original weight array.
     // Don't switch the sequence below, because the naive engine executes
     // pushAsync synchronously.
-    weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc());
-    weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), 1);
+    weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_desc());
+    weight_mem = GetWeights(weight, fwd.fwd_pd.weights_desc(), 1);
   } else {
     weight_mem = weight.GetMKLDNNData();
-    CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc());
+    CHECK(weight_mem->get_desc() == fwd.fwd_pd.weights_desc());
   }
-  auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], fwd.fwd_pd.dst_primitive_desc(),
+  auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], fwd.fwd_pd.dst_desc(),
                                  req[fullc::kOut]);
-  const mkldnn::memory *bias_mem = nullptr;
-  if (!param.no_bias)
-    bias_mem = quantized_bias.GetMKLDNNDataReorder(fwd.fwd_pd.bias_primitive_desc());
-
-  fwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second);
-  MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+  mkldnn_args_map_t args = {
+    {MKLDNN_ARG_SRC, *data_mem},
+    {MKLDNN_ARG_WEIGHTS, *weight_mem},
+    {MKLDNN_ARG_DST, *out_mem.second},
+  };
+
+  const mkldnn::memory *bias_mem = nullptr;
+  if (!param.no_bias) {
+    bias_mem = quantized_bias.GetMKLDNNDataReorder(fwd.fwd_pd.bias_desc());
+    args[MKLDNN_ARG_BIAS] = *bias_mem;
+  }
+
+  MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), args);
   CommitOutput(out_data[fullc::kOut], out_mem);
   MKLDNNStream::Get()->Submit();
 }

 }  // namespace op
 }  // namespace mxnet

-#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_USE_MKLDNN == 100
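
Beyond the argument map, this file shows the second systematic v0.x → v1.0 rename: primitive-descriptor queries such as src_primitive_desc() become plain memory-descriptor queries such as src_desc(). In v0.x a memory::primitive_desc carried the engine, whereas a v1.0 memory::desc does not, so the engine is passed separately when a memory object is created. A sketch under that assumption (make_src_memory is an illustrative helper, not part of this PR):

    #include <mkldnn.hpp>

    // fwd_pd stands in for fwd.fwd_pd from the hunk above.
    mkldnn::memory make_src_memory(
        const mkldnn::inner_product_forward::primitive_desc &fwd_pd,
        const mkldnn::engine &eng) {
      // v0.x: auto src_mpd = fwd_pd.src_primitive_desc();
      // v1.0: the query returns a plain memory::desc ...
      mkldnn::memory::desc src_md = fwd_pd.src_desc();
      // ... and the engine is supplied at construction; this form lets
      // the library allocate the buffer itself.
      return mkldnn::memory(src_md, eng);
    }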
src/operator/quantization/mkldnn/mkldnn_quantized_ops-inl.h
@@ -27,7 +27,7 @@
 #ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZED_OPS_INL_H_
 #define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZED_OPS_INL_H_

-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100

 #include <mxnet/ndarray.h>
 #include <vector>
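
This header only needs its version guard bumped. Every guard in the PR moves from == 1 to == 100; the apparent convention (inferred from the mkldnn-v1.0 branch, not stated in this diff) is that MXNET_USE_MKLDNN encodes which integration is compiled in, so the two code paths can be selected at build time:

    #if MXNET_USE_MKLDNN == 100
    // Built against MKL-DNN v1.0: desc()-based queries and
    // RegisterPrimArgs with an explicit argument map.
    #elif MXNET_USE_MKLDNN == 1
    // Built against MKL-DNN v0.x: primitive_desc()-based queries,
    // SetNewMem plus RegisterPrim.
    #else
    // MKL-DNN disabled: plain CPU fallback paths only.
    #endif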
10 changes: 5 additions & 5 deletions src/operator/quantization/quantized_fully_connected.cc
@@ -26,7 +26,7 @@
 #include <vector>
 #include "quantization_utils.h"
 #include "../nn/fully_connected-inl.h"
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "../nn/mkldnn/mkldnn_fully_connected-inl.h"
 #include "mkldnn/mkldnn_quantized_ops-inl.h"
 #endif
@@ -94,7 +94,7 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_type->size(), num_inputs * 3);
   CHECK_EQ(out_type->size(), 3U);

-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   CHECK(in_type->at(0) == mshadow::kInt8 || in_type->at(0) == mshadow::kUint8)
     << "QuantizedFullyConnected only supports int8/uint8 input, while "
     << in_type->at(0) << " is given.";
@@ -124,7 +124,7 @@ bool QuantizedFullyConnectedStorageType(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_attrs->size(), num_inputs * 3);
   CHECK_EQ(out_attrs->size(), 3U);

-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   return MKLDNNStorageType(attrs, dev_mask, true,
                            dispatch_mode, in_attrs, out_attrs);
 #else
@@ -292,7 +292,7 @@ void QuantizedFullyConnectedForwardCPU(const nnvm::NodeAttrs& attrs,
 #endif
 }

-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 void QuantizedFullyConnectedForwardExCPU(const nnvm::NodeAttrs &attrs,
                                          const OpContext &ctx,
                                          const std::vector<NDArray> &in_data,
@@ -341,7 +341,7 @@ and max thresholds representing the thresholds for quantizing the float32 output
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
 .set_attr<FCompute>("FCompute<cpu>", QuantizedFullyConnectedForwardCPU)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", QuantizedFullyConnectedForwardExCPU)
 #endif
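
The registration hunk keeps QuantizedFullyConnectedForwardCPU as the generic FCompute fallback and, under the v1.0 guard, adds an FComputeEx entry point plus the TIsMKLDNN flag; FComputeEx receives NDArrays, which lets MKL-DNN carry its internal memory layouts between operators. Presumably (its body is not shown in this excerpt) the Ex function is a thin wrapper over the MKL-DNN forward from mkldnn_quantized_fully_connected.cc:

    #if MXNET_USE_MKLDNN == 100
    // Sketch only: delegate the NDArray-based path to the MKL-DNN kernel.
    void QuantizedFullyConnectedForwardExCPU(const nnvm::NodeAttrs &attrs,
                                             const OpContext &ctx,
                                             const std::vector<NDArray> &in_data,
                                             const std::vector<OpReqType> &req,
                                             const std::vector<NDArray> &out_data) {
      MKLDNNQuantizedFullyConnectedForward(attrs, ctx, in_data, req, out_data);
    }
    #endif  // MXNET_USE_MKLDNN == 100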
2 changes: 1 addition & 1 deletion tests/python/quantization/test_quantization.py
@@ -407,7 +407,7 @@ def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_p
 def test_quantized_fc():
     def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
         if is_test_for_native_cpu():
-            hasMKL = False;
+            hasMKL = False
             for key in os.environ.keys():
                 if operator.eq(key, "BUILD_TAG"):
                     if os.environ['BUILD_TAG'].find("MKL") != -1: