From 1c20a8779204854a039a1e82fdc1052bb95e03e8 Mon Sep 17 00:00:00 2001
From: "Li, Hao H"
Date: Tue, 16 Oct 2018 13:43:28 +0800
Subject: [PATCH 01/14] support quantized fc in cpu

---
 .../mkldnn_quantized_fully_connected.cc       | 134 ++++++++++++++++++
 .../quantization/quantized_fully_connected.cc |  17 +++
 .../python/quantization/test_quantization.py  |  25 ++--
 3 files changed, 164 insertions(+), 12 deletions(-)
 create mode 100644 src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc

diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
new file mode 100644
index 000000000000..c39c33c92fca
--- /dev/null
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#if MXNET_USE_MKLDNN == 1
+
+#include "../../nn/mkldnn/mkldnn_base-inl.h"
+#include "../quantization_utils.h"
+#include "../../nn/fully_connected-inl.h"
+
+namespace mxnet {
+namespace op {
+
+// value + bias_value * (range1 / limit_range1) * (limit_range2 / range2)
+struct QuantizedBiasAddKernel {
+  MSHADOW_XINLINE static void Map(int i, size_t k, int32_t *out,
+                                  const int8_t *bias, const float *min_out,
+                                  const float *max_out, const float *min_bias,
+                                  const float *max_bias) {
+    typedef int32_t T1;
+    typedef int8_t T2;
+    using mshadow::red::limits::MinValue;
+    using mshadow::red::limits::MaxValue;
+    float float_for_one_out_quant =
+      MaxAbs(*min_out, *max_out) / static_cast<double>(MaxValue<T1>());
+    float float_for_one_bias_quant =
+      MaxAbs(*min_bias, *max_bias) / static_cast<double>(MaxValue<T2>());
+    if (float_for_one_out_quant != 0) {
+      out[i] = (out[i] * float_for_one_out_quant +
+                bias[i%k] * float_for_one_bias_quant) /
+               float_for_one_out_quant;
+    } else {
+      LOG(INFO) << "WARNING: QuantizedBiasAddKernel float_for_one_out_quant is 0 !";
+    }
+  }
+};
+
+template<typename SrcType, typename DstType, typename CmpType>
+void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
+    const OpContext &ctx,
+    const std::vector<NDArray> &in_data,
+    const std::vector<OpReqType> &req,
+    const std::vector<NDArray> &out_data) {
+  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
+  using namespace mshadow;
+  using namespace mxnet_op;
+  size_t num_inputs = param.no_bias ? 2 : 3;
+  CHECK_EQ(in_data.size(), num_inputs * 3);
+  CHECK_EQ(out_data.size(), 3U);
+  const NDArray& data = in_data[0];
+  const NDArray& weight = in_data[1];
+  const NDArray& out = out_data[0];
+  TShape dshape = data.shape();
+  TShape wshape = weight.shape();
+  TShape oshape = out.shape();
+
+  CHECK(in_data[0].dtype() == mshadow::kInt8
+        && in_data[1].dtype() == mshadow::kInt8)
+    << "mkldnn_quantized_FullyConnected op only supports int8 as input type";
+
+  const float alpha = 1.0f;
+  const float beta = 0.0f;
+  const CBLAS_OFFSET offsetc = CblasFixOffset;
+  const MKL_INT8 oa = -128;
+  const MKL_INT8 ob = 0;
+  MKL_INT32 oc = 0;
+  const int m = dshape[0], n = wshape[0], k = dshape.ProdShape(1, dshape.ndim());
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  uint8_t* pDataNewRange = reinterpret_cast<uint8_t*>(malloc(m*k*sizeof(uint8_t)));
+
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = 0; i < m * k; i++) {
+    pDataNewRange[i] = data.data().dptr<int8_t>()[i] + 128;
+  }
+
+  cblas_gemm_s8u8s32(CblasRowMajor,
+                     CblasNoTrans,
+                     CblasTrans,
+                     offsetc,
+                     m,
+                     n,
+                     k,
+                     alpha,
+                     pDataNewRange,
+                     k,
+                     oa,
+                     weight.data().dptr<int8_t>(),
+                     k,
+                     ob,
+                     beta,
+                     out.data().dptr<int32_t>(),
+                     n,
+                     &oc);
+
+  free(pDataNewRange);
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1,
+    out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(),
+    in_data[num_inputs].data().dptr<float>(), in_data[num_inputs+1].data().dptr<float>(),
+    in_data[num_inputs+2].data().dptr<float>(), in_data[num_inputs+3].data().dptr<float>());
+
+  if (!param.no_bias) {
+    const NDArray& bias = in_data[2];
+    Kernel<QuantizedBiasAddKernel, cpu>::Launch(s, out.shape().Size(),
+      n, out.data().dptr<int32_t>(), bias.data().dptr<int8_t>(),
+      out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(),
+      in_data[7].data().dptr<float>(), in_data[8].data().dptr<float>());
+  }
+}
+
+NNVM_REGISTER_OP(_contrib_quantized_fully_connected)
+.set_attr<FComputeEx>("FComputeEx<cpu>",
+  MKLDNNQuantizedFullyConnectedForward<int8_t, int32_t, int32_t>);
+
+
+}  // namespace op
+}  // namespace mxnet
+#endif
+
diff --git a/src/operator/quantization/quantized_fully_connected.cc b/src/operator/quantization/quantized_fully_connected.cc
index e334fe7ec9b2..72ad19f11d2e 100644
--- a/src/operator/quantization/quantized_fully_connected.cc
+++ b/src/operator/quantization/quantized_fully_connected.cc
@@ -79,6 +79,22 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+bool QuantizedFullyConnectedStorageType(const nnvm::NodeAttrs& attrs,
+    const int dev_mask,
+    DispatchMode* dispatch_mode,
+    std::vector<int> *in_attrs,
+    std::vector<int> *out_attrs) {
+  *dispatch_mode = DispatchMode::kFCompute;
+#if MXNET_USE_MKLDNN == 1
+  if (dev_mask == mshadow::cpu::kDevMask) {
+    *dispatch_mode = DispatchMode::kFComputeEx;
+  }
+#endif
+  for (size_t i = 0; i < out_attrs->size(); i++)
+    (*out_attrs)[i] = kDefaultStorage;
+  return true;
+}
+
 NNVM_REGISTER_OP(_contrib_quantized_fully_connected)
 .describe(R"code(Fully Connected operator for input, weight and bias data type of int8,
 and accumulates in type int32 for the output. For each argument, two more arguments of type
@@ -112,6 +128,7 @@ and max thresholds representing the threholds for quantizing the float32 output
   })
 .set_attr<nnvm::FInferShape>("FInferShape", QuantizedFullyConnectedShape)
 .set_attr<nnvm::FInferType>("FInferType", QuantizedFullyConnectedType)
+.set_attr<FInferStorageType>("FInferStorageType", QuantizedFullyConnectedStorageType)
 .set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
 .add_argument("data", "NDArray-or-Symbol", "Input data.")
 .add_argument("weight", "NDArray-or-Symbol", "weight.")
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index 5ae2c6c398e9..d85cf2c4eeed 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -269,10 +269,7 @@ def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_p
 @with_seed()
 def test_quantized_fc():
     def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
-        if mx.current_context().device_type != 'gpu':
-            print('skipped testing quantized_fc on cpu since it is not supported yet')
-            return
-        elif qdtype == 'uint8' and is_test_for_gpu():
+        if qdtype == 'uint8' and is_test_for_gpu():
             print('skipped testing quantized_fc for gpu uint8 since it is not supported yet')
             return
 
@@ -283,17 +280,17 @@ def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
         fc_fp32_exe = fc_fp32.simple_bind(ctx=mx.current_context(), grad_req='null')
         if qdtype == 'uint8':
             data_low = 0.0
-            data_high = 127.0
+            data_high = 63.0
         else:
-            data_low = -127.0
-            data_high = 127.0
+            data_low = -63.0
+            data_high = 63.0
         fc_fp32_exe.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
-                                                                     shape=data_shape).astype('int32')
-        fc_fp32_exe.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
-                                                                     shape=arg_shapes[1]).astype('int32')
+                                                                     shape=data_shape).astype('int8')
+        fc_fp32_exe.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
+                                                                     shape=arg_shapes[1]).astype('int8')
         if not no_bias:
-            fc_fp32_exe.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
-                                                                         shape=arg_shapes[2]).astype('int32')
+            fc_fp32_exe.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
+                                                                         shape=arg_shapes[2]).astype('int8')
         output = fc_fp32_exe.forward()[0]
 
         qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype='int8')
@@ -335,6 +332,10 @@ def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
         check_quantized_fc((32, 111, 2, 2), 100, True, qdtype)
         check_quantized_fc((32, 512, 2, 2), 100, False, qdtype)
         check_quantized_fc((32, 111, 2, 2), 100, False, qdtype)
+        check_quantized_fc((256, 2048, 2, 2), 800, False, qdtype)
+        check_quantized_fc((256, 111, 2, 2), 800, False, qdtype)
+        check_quantized_fc((256, 2048, 2, 2), 800, True, qdtype)
+        check_quantized_fc((256, 111, 2, 2), 800, True, qdtype)
 
 @with_seed()
 def test_quantized_flatten():

From 6aaa4c117eb87127fdb179d74f873bce95e8d78c Mon Sep 17 00:00:00 2001
From: "Li, Hao H"
Date: Tue, 16 Oct 2018 15:26:29 +0800
Subject: [PATCH 02/14] fix typo bug and register resource for shift data
 buffer to avoid malloc/free in each run

---
 .../mkldnn_quantized_fully_connected.cc       | 33 ++++++++-----------
 .../quantization/quantized_fully_connected.cc | 11 ++++---
 .../python/quantization/test_quantization.py  |  2 +-
 3 files changed, 21 insertions(+), 25 deletions(-)

diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
index c39c33c92fca..0ac807c1af47 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
@@ -50,29 +50,26 @@ struct QuantizedBiasAddKernel {
   }
 };
 
-template<typename SrcType, typename DstType, typename CmpType>
+template<typename SrcType>
 void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
-    const OpContext &ctx,
-    const std::vector<NDArray> &in_data,
-    const std::vector<OpReqType> &req,
-    const std::vector<NDArray> &out_data) {
+                                          const OpContext &ctx,
+                                          const std::vector<NDArray> &in_data,
+                                          const std::vector<OpReqType> &req,
+                                          const std::vector<NDArray> &out_data) {
   const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
   using namespace mshadow;
   using namespace mxnet_op;
   size_t num_inputs = param.no_bias ? 2 : 3;
   CHECK_EQ(in_data.size(), num_inputs * 3);
-  CHECK_EQ(out_data.size(), 3U);
-  const NDArray& data = in_data[0];
-  const NDArray& weight = in_data[1];
-  const NDArray& out = out_data[0];
+  CHECK_EQ(out_data.size(), 4U);
+  const NDArray& data   = in_data[0];
+  const NDArray& weight = in_data[1];
+  const NDArray& out    = out_data[0];
+  const NDArray& shift_data = out_data[3];
   TShape dshape = data.shape();
   TShape wshape = weight.shape();
   TShape oshape = out.shape();
 
-  CHECK(in_data[0].dtype() == mshadow::kInt8
-        && in_data[1].dtype() == mshadow::kInt8)
-    << "mkldnn_quantized_FullyConnected op only supports int8 as input type";
-
   const float alpha = 1.0f;
   const float beta = 0.0f;
   const CBLAS_OFFSET offsetc = CblasFixOffset;
@@ -80,12 +77,9 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
   const MKL_INT8 ob = 0;
   MKL_INT32 oc = 0;
   const int m = dshape[0], n = wshape[0], k = dshape.ProdShape(1, dshape.ndim());
-  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
-  uint8_t* pDataNewRange = reinterpret_cast<uint8_t*>(malloc(m*k*sizeof(uint8_t)));
 
-  #pragma omp parallel for num_threads(omp_threads)
   for (int i = 0; i < m * k; i++) {
-    pDataNewRange[i] = data.data().dptr<int8_t>()[i] + 128;
+    shift_data.data().dptr<uint8_t>()[i] = data.data().dptr<int8_t>()[i] + 128;
   }
 
   cblas_gemm_s8u8s32(CblasRowMajor,
@@ -96,7 +90,7 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
                      n,
                      k,
                      alpha,
-                     pDataNewRange,
+                     shift_data.data().dptr<uint8_t>(),
                      k,
                      oa,
                      weight.data().dptr<int8_t>(),
@@ -107,7 +101,6 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
                      n,
                      &oc);
 
-  free(pDataNewRange);
   Stream<cpu> *s = ctx.get_stream<cpu>();
   Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1,
     out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(),
@@ -125,7 +118,7 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
 
 NNVM_REGISTER_OP(_contrib_quantized_fully_connected)
 .set_attr<FComputeEx>("FComputeEx<cpu>",
-  MKLDNNQuantizedFullyConnectedForward<int8_t, int32_t, int32_t>);
+  MKLDNNQuantizedFullyConnectedForward<int8_t>);
 
 
 }  // namespace op
diff --git a/src/operator/quantization/quantized_fully_connected.cc b/src/operator/quantization/quantized_fully_connected.cc
index 72ad19f11d2e..98f7f6a1dafb 100644
--- a/src/operator/quantization/quantized_fully_connected.cc
+++ b/src/operator/quantization/quantized_fully_connected.cc
@@ -36,7 +36,7 @@ bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs,
   using namespace mshadow;
   uint32_t num_inputs = param.no_bias ? 2 : 3;
   CHECK_EQ(in_shape->size(), num_inputs * 3);
-  CHECK_EQ(out_shape->size(), 3U);
+  CHECK_EQ(out_shape->size(), 4U);
 
   CHECK(!shape_is_none(in_shape->at(0)))
     << "QuantizedFullyConnectedOp input data shape must be given";
@@ -55,6 +55,7 @@ bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs,
   SHAPE_ASSIGN_CHECK(*out_shape, 0, TShape({dshape[0], wshape[0]}));
   SHAPE_ASSIGN_CHECK(*out_shape, 1, TShape({1}));
   SHAPE_ASSIGN_CHECK(*out_shape, 2, TShape({1}));
+  SHAPE_ASSIGN_CHECK(*out_shape, 3, dshape);
   return true;
 }
 
@@ -64,7 +65,7 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs,
   const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
   uint32_t num_inputs = param.no_bias ? 2 : 3;
   CHECK_EQ(in_type->size(), num_inputs * 3);
-  CHECK_EQ(out_type->size(), 3U);
+  CHECK_EQ(out_type->size(), 4U);
 
   for (size_t i = 0; i < num_inputs; ++i) {
     TYPE_ASSIGN_CHECK(*in_type, i, mshadow::kInt8);
@@ -76,6 +77,7 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs,
   TYPE_ASSIGN_CHECK(*out_type, 0, mshadow::kInt32);
   TYPE_ASSIGN_CHECK(*out_type, 1, mshadow::kFloat32);
   TYPE_ASSIGN_CHECK(*out_type, 2, mshadow::kFloat32);
+  TYPE_ASSIGN_CHECK(*out_type, 3, mshadow::kUint8);
   return true;
 }
 
@@ -109,7 +111,7 @@ and max thresholds representing the threholds for quantizing the float32 output
     const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
     return param.no_bias? 6 : 9;
   })
-.set_num_outputs(3)
+.set_num_outputs(4)
 .set_attr_parser(ParamParser<FullyConnectedParam>)
 .set_attr<nnvm::FListInputNames>("FListInputNames",
   [](const NodeAttrs& attrs) {
@@ -124,12 +126,13 @@ and max thresholds representing the threholds for quantizing the float32 output
   })
 .set_attr<nnvm::FListOutputNames>("FListOutputNames",
   [](const NodeAttrs& attrs) {
-    return std::vector<std::string>{"output", "min_output", "max_output"};
+    return std::vector<std::string>{"output", "min_output", "max_output", "shift_data"};
   })
 .set_attr<nnvm::FInferShape>("FInferShape", QuantizedFullyConnectedShape)
 .set_attr<nnvm::FInferType>("FInferType", QuantizedFullyConnectedType)
 .set_attr<FInferStorageType>("FInferStorageType", QuantizedFullyConnectedStorageType)
 .set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
+.add_argument("shiftdata", "NDArray-or-Symbol", "Input data.")
 .add_argument("data", "NDArray-or-Symbol", "Input data.")
 .add_argument("weight", "NDArray-or-Symbol", "weight.")
 .add_argument("bias", "NDArray-or-Symbol", "bias.")
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index d85cf2c4eeed..5e3c5c2844ad 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -317,7 +317,7 @@ def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
         fc_int8_exe.arg_dict[qarg_names[6]][:] = quantized_range
         fc_int8_exe.arg_dict[qarg_names[7]][:] = -quantized_range
         fc_int8_exe.arg_dict[qarg_names[8]][:] = quantized_range
-        qoutput, min_range, max_range = fc_int8_exe.forward()
+        qoutput, min_range, max_range, shift_data = fc_int8_exe.forward()
 
         if no_bias:
             assert_almost_equal(output.asnumpy(), qoutput.asnumpy())

From d62481dbced4488f6cc16012c7a49b1571148811 Mon Sep 17 00:00:00 2001
From: "Li, Hao H"
Date: Tue, 16 Oct 2018 16:12:10 +0800
Subject: [PATCH 03/14] add comment for s8u8 input

---
 .../quantization/mkldnn/mkldnn_quantized_fully_connected.cc | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
index 0ac807c1af47..0e20ac43b98b 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
@@ -79,6 +79,8 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
   const int m = dshape[0], n = wshape[0], k = dshape.ProdShape(1, dshape.ndim());
 
   for (int i = 0; i < m * k; i++) {
+    // cblas_gemm_s8u8s32 requires the first matrix to be uint8
+    // shift data from int8 (from -128 to 127) to uint8 (from 0 to 255)
     shift_data.data().dptr<uint8_t>()[i] = data.data().dptr<int8_t>()[i] + 128;
   }
 

From 408009226c6ce7c71beaab76b911d8232c76a938 Mon Sep 17 00:00:00 2001
From: "Li, Hao H"
Date: Wed, 17 Oct 2018 09:34:45 +0800
Subject: [PATCH 04/14] optimized shift

---
 .../mkldnn_quantized_fully_connected.cc       | 28 +++++++++++--------
 1 file changed, 17 insertions(+), 11 deletions(-)

diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
index 0e20ac43b98b..e25a46593dcd 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
@@ -50,6 +50,13 @@ struct QuantizedBiasAddKernel {
   }
 };
 
+struct QuantizedShiftKernel {
+  MSHADOW_XINLINE static void Map(int i, int8_t *in, uint8_t *out, int shift) {
+    out[i] = in[i] + shift;
+  }
+};
+
+
 template<typename SrcType>
 void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
                                           const OpContext &ctx,
@@ -78,11 +85,11 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
   MKL_INT32 oc = 0;
   const int m = dshape[0], n = wshape[0], k = dshape.ProdShape(1, dshape.ndim());
 
-  for (int i = 0; i < m * k; i++) {
-    // cblas_gemm_s8u8s32 requires the first matrix to be uint8
-    // shift data from int8 (from -128 to 127) to uint8 (from 0 to 255)
-    shift_data.data().dptr<uint8_t>()[i] = data.data().dptr<int8_t>()[i] + 128;
-  }
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  // cblas_gemm_s8u8s32 requires the first matrix to be uint8
+  // shift data from int8 (from -128 to 127) to uint8 (from 0 to 255)
+  Kernel<QuantizedShiftKernel, cpu>::Launch(s, m * k, data.data().dptr<int8_t>(),
+      shift_data.data().dptr<uint8_t>(), 128);
 
   cblas_gemm_s8u8s32(CblasRowMajor,
@@ -107,17 +114,16 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
                      n,
                      &oc);
-
-  Stream<cpu> *s = ctx.get_stream<cpu>();
+
   Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1,
-    out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(),
-    in_data[num_inputs].data().dptr<float>(), in_data[num_inputs+1].data().dptr<float>(),
-    in_data[num_inputs+2].data().dptr<float>(), in_data[num_inputs+3].data().dptr<float>());
+      out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(),
+      in_data[num_inputs].data().dptr<float>(), in_data[num_inputs+1].data().dptr<float>(),
+      in_data[num_inputs+2].data().dptr<float>(), in_data[num_inputs+3].data().dptr<float>());
 
   if (!param.no_bias) {
     const NDArray& bias = in_data[2];
     Kernel<QuantizedBiasAddKernel, cpu>::Launch(s, out.shape().Size(),
       n, out.data().dptr<int32_t>(), bias.data().dptr<int8_t>(),
       out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(),
-      in_data[7].data().dptr<float>(), in_data[8].data().dptr<float>());
+      in_data[7].data().dptr<float>(), in_data[8].data().dptr<float>());
   }
 }

From e03378b8efa27139752dd56f5cad9307f2114841 Mon Sep 17 00:00:00 2001
From: "Li, Hao H"
Date: Wed, 17 Oct 2018 09:38:07 +0800
Subject: [PATCH 05/14] fix lint issue

---
 .../quantization/mkldnn/mkldnn_quantized_fully_connected.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
index e25a46593dcd..29c18c96f329 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
@@ -88,7 +88,7 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
   Stream<cpu> *s = ctx.get_stream<cpu>();
   // cblas_gemm_s8u8s32 requires the first matrix to be uint8
   // shift data from int8 (from -128 to 127) to uint8 (from 0 to 255)
-  Kernel<QuantizedShiftKernel, cpu>::Launch(s, m * k, data.data().dptr<int8_t>(), 
+  Kernel<QuantizedShiftKernel, cpu>::Launch(s, m * k, data.data().dptr<int8_t>(),
       shift_data.data().dptr<uint8_t>(), 128);
 
   cblas_gemm_s8u8s32(CblasRowMajor,
@@ -109,7 +109,7 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
                      n,
                      &oc);
- 
+
   Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1,
       out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(),
       in_data[num_inputs].data().dptr<float>(), in_data[num_inputs+1].data().dptr<float>(),

From ffc8064009f1d0b31f88d65ebadfba46ba19543c Mon Sep 17 00:00:00 2001
From: "Li, Hao H"
Date: Wed, 17 Oct 2018 12:36:12 +0800
Subject: [PATCH 06/14] optimized for bias and offset

---
 .../mkldnn_quantized_fully_connected.cc       | 75 ++++++++++++-------
 1 file changed, 49 insertions(+), 26 deletions(-)

diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
index 29c18c96f329..5153fb9e6ab8 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
@@ -26,9 +26,15 @@
 namespace mxnet {
 namespace op {
 
-// value + bias_value * (range1 / limit_range1) * (limit_range2 / range2)
-struct QuantizedBiasAddKernel {
-  MSHADOW_XINLINE static void Map(int i, size_t k, int32_t *out,
+struct QuantizedShiftKernel {
+  MSHADOW_XINLINE static void Map(int i, int8_t *in, uint8_t *out, int shift) {
+    out[i] = in[i] + shift;
+  }
+};
+
+struct QuantizedSumInitKernelWithBias {
+  // init sum data with bias for matrix b (n)
+  MSHADOW_XINLINE static void Map(int i, int32_t *out,
                                   const int8_t *bias, const float *min_out,
                                   const float *max_out, const float *min_bias,
                                   const float *max_bias) {
@@ -41,21 +47,35 @@ struct QuantizedSumInitKernelWithBias {
     float float_for_one_bias_quant =
       MaxAbs(*min_bias, *max_bias) / static_cast<double>(MaxValue<T2>());
     if (float_for_one_out_quant != 0) {
-      out[i] = (out[i] * float_for_one_out_quant +
-                bias[i%k] * float_for_one_bias_quant) /
+      out[i] = bias[i] * float_for_one_bias_quant /
                float_for_one_out_quant;
     } else {
       LOG(INFO) << "WARNING: QuantizedBiasAddKernel float_for_one_out_quant is 0 !";
+      out[i] = 0;
     }
   }
 };
 
-struct QuantizedShiftKernel {
-  MSHADOW_XINLINE static void Map(int i, int8_t *in, uint8_t *out, int shift) {
-    out[i] = in[i] + shift;
+struct QuantizedSumInitKernel {
+  // init sum data for matrix b (n)
+  MSHADOW_XINLINE static void Map(int i, int32_t *out) {
+    out[i] = 0;
+  }
+};
+
+struct QuantizedSumKernel {
+  // get sum data(n) for matrix b (n * k)
+  MSHADOW_XINLINE static void Map(int i, size_t k, int8_t *in, int32_t *out, int shift) {
+    out[i / k] -= shift * in[i];
   }
 };
 
+struct QuantizedBetaCKernel {
+  // prepare beta C (from n to m * n)
+  MSHADOW_XINLINE static void Map(int i, size_t n, int32_t *out) {
+    out[i] = out[i % n];
+  }
+};
 
 template<typename SrcType>
 void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
                                           const OpContext &ctx,
@@ -92,18 +112,34 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
   TShape oshape = out.shape();
 
   const float alpha = 1.0f;
-  const float beta = 0.0f;
+  const float beta = 1.0f;
   const CBLAS_OFFSET offsetc = CblasFixOffset;
-  const MKL_INT8 oa = -128;
+  const MKL_INT8 oa = 0;
   const MKL_INT8 ob = 0;
   MKL_INT32 oc = 0;
   const int m = dshape[0], n = wshape[0], k = dshape.ProdShape(1, dshape.ndim());
   Stream<cpu> *s = ctx.get_stream<cpu>();
   // cblas_gemm_s8u8s32 requires the first matrix to be uint8
   // shift data from int8 (from -128 to 127) to uint8 (from 0 to 255)
+  int shift = 128;
   Kernel<QuantizedShiftKernel, cpu>::Launch(s, m * k, data.data().dptr<int8_t>(),
-      shift_data.data().dptr<uint8_t>(), 128);
+      shift_data.data().dptr<uint8_t>(), shift);
+  Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1,
+    out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(),
+    in_data[num_inputs].data().dptr<float>(), in_data[num_inputs+1].data().dptr<float>(),
+    in_data[num_inputs+2].data().dptr<float>(), in_data[num_inputs+3].data().dptr<float>());
+  if (!param.no_bias) {
+    const NDArray& bias = in_data[2];
+    Kernel<QuantizedSumInitKernelWithBias, cpu>::Launch(s, n, out.data().dptr<int32_t>(),
+      bias.data().dptr<int8_t>(), out_data[1].data().dptr<float>(),
+      out_data[2].data().dptr<float>(), in_data[7].data().dptr<float>(),
+      in_data[8].data().dptr<float>());
+  } else {
+    Kernel<QuantizedSumInitKernel, cpu>::Launch(s, n, out.data().dptr<int32_t>());
+  }
+  Kernel<QuantizedSumKernel, cpu>::Launch(s, n * k, k, weight.data().dptr<int8_t>(),
+      out.data().dptr<int32_t>(), shift);
+
+  Kernel<QuantizedBetaCKernel, cpu>::Launch(s, m * n, n, out.data().dptr<int32_t>());
 
   cblas_gemm_s8u8s32(CblasRowMajor,
                      CblasNoTrans,
                      CblasTrans,
@@ -109,19 +146,6 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
                      n,
                      &oc);
-
-  Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1,
-      out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(),
-      in_data[num_inputs].data().dptr<float>(), in_data[num_inputs+1].data().dptr<float>(),
-      in_data[num_inputs+2].data().dptr<float>(), in_data[num_inputs+3].data().dptr<float>());
-
-  if (!param.no_bias) {
-    const NDArray& bias = in_data[2];
-    Kernel<QuantizedBiasAddKernel, cpu>::Launch(s, out.shape().Size(),
-      n, out.data().dptr<int32_t>(), bias.data().dptr<int8_t>(),
-      out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(),
-      in_data[7].data().dptr<float>(), in_data[8].data().dptr<float>());
-  }
 }
 
 NNVM_REGISTER_OP(_contrib_quantized_fully_connected)
@@ -132,4 +156,3 @@ NNVM_REGISTER_OP(_contrib_quantized_fully_connected)
 
 }  // namespace op
 }  // namespace mxnet
 #endif
-

From cf9c0744bc4e33c8bddf8a51b7e16b50ca3da9d9 Mon Sep 17 00:00:00 2001
From: "Li, Hao H"
Date: Wed, 17 Oct 2018 16:40:09 +0800
Subject: [PATCH 07/14] fix typo bug

---
 src/operator/quantization/quantized_fully_connected.cc | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/operator/quantization/quantized_fully_connected.cc b/src/operator/quantization/quantized_fully_connected.cc
index 98f7f6a1dafb..eead40a586b0 100644
--- a/src/operator/quantization/quantized_fully_connected.cc
+++ b/src/operator/quantization/quantized_fully_connected.cc
@@ -132,7 +132,6 @@ and max thresholds representing the threholds for quantizing the float32 output
 .set_attr<nnvm::FInferType>("FInferType", QuantizedFullyConnectedType)
 .set_attr<FInferStorageType>("FInferStorageType", QuantizedFullyConnectedStorageType)
 .set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
-.add_argument("shiftdata", "NDArray-or-Symbol", "Input data.")
 .add_argument("data", "NDArray-or-Symbol", "Input data.")
 .add_argument("weight", "NDArray-or-Symbol", "weight.")
 .add_argument("bias", "NDArray-or-Symbol", "bias.")

From f84668c38a04ed8be7ad647fb2e1e464bec0b127 Mon Sep 17 00:00:00 2001
From: "Li, Hao H"
Date: Thu, 18 Oct 2018 12:05:17 +0800
Subject: [PATCH 08/14] fix resource registration issue

---
 .../mkldnn_quantized_fully_connected.cc       | 21 +++++++++++++------
 .../quantization/quantized_fully_connected.cc | 10 ++++-----
 .../python/quantization/test_quantization.py  |  2 +-
 3 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
index 5153fb9e6ab8..be6c29d8d1ce 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
@@ -26,6 +26,10 @@
 namespace mxnet {
 namespace op {
 
+namespace qfc {
+enum QfcOpResource {kTempSpace};
+}
+
 struct QuantizedShiftKernel {
   MSHADOW_XINLINE static void Map(int i, int8_t *in, uint8_t *out, int shift) {
     out[i] = in[i] + shift;
@@ -92,11 +96,10 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
   using namespace mxnet_op;
   size_t num_inputs = param.no_bias ? 2 : 3;
   CHECK_EQ(in_data.size(), num_inputs * 3);
-  CHECK_EQ(out_data.size(), 4U);
+  CHECK_EQ(out_data.size(), 3U);
   const NDArray& data   = in_data[0];
   const NDArray& weight = in_data[1];
   const NDArray& out    = out_data[0];
-  const NDArray& shift_data = out_data[3];
   TShape dshape = data.shape();
   TShape wshape = weight.shape();
   TShape oshape = out.shape();
@@ -108,8 +111,11 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
   // cblas_gemm_s8u8s32 requires the first matrix to be uint8
   // shift data from int8 (from -128 to 127) to uint8 (from 0 to 255)
   int shift = 128;
+  Tensor<cpu, 1, uint8_t> shiftdata =
+    ctx.requested[qfc::kTempSpace].get_space_typed<cpu, 1, uint8_t>(
+      Shape1(m * k), s);
   Kernel<QuantizedShiftKernel, cpu>::Launch(s, m * k, data.data().dptr<int8_t>(),
-      shift_data.data().dptr<uint8_t>(), shift);
+      shiftdata.dptr_, shift);
   Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1,
     out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(),
     in_data[num_inputs].data().dptr<float>(), in_data[num_inputs+1].data().dptr<float>(),
@@ -139,7 +145,7 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
                      n,
                      k,
                      alpha,
-                     shift_data.data().dptr<uint8_t>(),
+                     shiftdata.dptr_,
                      k,
                      oa,
                      weight.data().dptr<int8_t>(),
@@ -153,8 +159,11 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
 
 NNVM_REGISTER_OP(_contrib_quantized_fully_connected)
 .set_attr<FComputeEx>("FComputeEx<cpu>",
-  MKLDNNQuantizedFullyConnectedForward<int8_t>);
-
+  MKLDNNQuantizedFullyConnectedForward<int8_t>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  });
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/quantization/quantized_fully_connected.cc b/src/operator/quantization/quantized_fully_connected.cc
index eead40a586b0..72ad19f11d2e 100644
--- a/src/operator/quantization/quantized_fully_connected.cc
+++ b/src/operator/quantization/quantized_fully_connected.cc
@@ -36,7 +36,7 @@ bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs,
   using namespace mshadow;
   uint32_t num_inputs = param.no_bias ? 2 : 3;
   CHECK_EQ(in_shape->size(), num_inputs * 3);
-  CHECK_EQ(out_shape->size(), 4U);
+  CHECK_EQ(out_shape->size(), 3U);
 
   CHECK(!shape_is_none(in_shape->at(0)))
     << "QuantizedFullyConnectedOp input data shape must be given";
@@ -55,7 +55,6 @@ bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs,
   SHAPE_ASSIGN_CHECK(*out_shape, 0, TShape({dshape[0], wshape[0]}));
   SHAPE_ASSIGN_CHECK(*out_shape, 1, TShape({1}));
   SHAPE_ASSIGN_CHECK(*out_shape, 2, TShape({1}));
-  SHAPE_ASSIGN_CHECK(*out_shape, 3, dshape);
   return true;
 }
 
@@ -65,7 +64,7 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs,
   const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
   uint32_t num_inputs = param.no_bias ? 2 : 3;
   CHECK_EQ(in_type->size(), num_inputs * 3);
-  CHECK_EQ(out_type->size(), 4U);
+  CHECK_EQ(out_type->size(), 3U);
 
   for (size_t i = 0; i < num_inputs; ++i) {
     TYPE_ASSIGN_CHECK(*in_type, i, mshadow::kInt8);
@@ -77,7 +76,6 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs,
   TYPE_ASSIGN_CHECK(*out_type, 0, mshadow::kInt32);
   TYPE_ASSIGN_CHECK(*out_type, 1, mshadow::kFloat32);
   TYPE_ASSIGN_CHECK(*out_type, 2, mshadow::kFloat32);
-  TYPE_ASSIGN_CHECK(*out_type, 3, mshadow::kUint8);
   return true;
 }
 
@@ -111,7 +109,7 @@ and max thresholds representing the threholds for quantizing the float32 output
     const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
     return param.no_bias? 6 : 9;
   })
-.set_num_outputs(4)
+.set_num_outputs(3)
 .set_attr_parser(ParamParser<FullyConnectedParam>)
 .set_attr<nnvm::FListInputNames>("FListInputNames",
   [](const NodeAttrs& attrs) {
@@ -126,13 +124,12 @@ and max thresholds representing the threholds for quantizing the float32 output
   })
 .set_attr<nnvm::FListOutputNames>("FListOutputNames",
   [](const NodeAttrs& attrs) {
-    return std::vector<std::string>{"output", "min_output", "max_output", "shift_data"};
+    return std::vector<std::string>{"output", "min_output", "max_output"};
   })
 .set_attr<nnvm::FInferShape>("FInferShape", QuantizedFullyConnectedShape)
 .set_attr<nnvm::FInferType>("FInferType", QuantizedFullyConnectedType)
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index 5e3c5c2844ad..d85cf2c4eeed 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -317,7 +317,7 @@ def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
         fc_int8_exe.arg_dict[qarg_names[6]][:] = quantized_range
         fc_int8_exe.arg_dict[qarg_names[7]][:] = -quantized_range
         fc_int8_exe.arg_dict[qarg_names[8]][:] = quantized_range
-        qoutput, min_range, max_range, shift_data = fc_int8_exe.forward()
+        qoutput, min_range, max_range = fc_int8_exe.forward()
 
         if no_bias:
             assert_almost_equal(output.asnumpy(), qoutput.asnumpy())

From d4350e57c8c270524ff09749d27a32f06e0b9624 Mon Sep 17 00:00:00 2001
From: "Li, Hao H"
Date: Fri, 19 Oct 2018 12:24:13 +0800
Subject: [PATCH 09/14] fix typo bug

---
 .../mkldnn/mkldnn_quantized_fully_connected.cc | 12 ++++++------
 .../quantization/quantized_fully_connected.cc  |  8 ++++----
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
index be6c29d8d1ce..f91dcf0d2010 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
@@ -26,8 +26,8 @@
 namespace mxnet {
 namespace op {
 
-namespace qfc {
-enum QfcOpResource {kTempSpace};
+namespace quantilizedfc {
+enum QuantilizedfcOpResource {kTempSpace};
 }
 
@@ -112,9 +112,9 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
   // shift data from int8 (from -128 to 127) to uint8 (from 0 to 255)
   int shift = 128;
   Tensor<cpu, 1, uint8_t> shiftdata =
-    ctx.requested[qfc::kTempSpace].get_space_typed<cpu, 1, uint8_t>(
+    ctx.requested[quantilizedfc::kTempSpace].get_space_typed<cpu, 1, uint8_t>(
       Shape1(m * k), s);
-  Kernel<QuantizedShiftKernel, cpu>::Launch(s, m * k, data.data().dptr<int8_t>(), 
+  Kernel<QuantizedShiftKernel, cpu>::Launch(s, m * k, data.data().dptr<int8_t>(),
       shiftdata.dptr_, shift);
   Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1,
     out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(),
@@ -129,7 +129,7 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
   } else {
     Kernel<QuantizedSumInitKernel, cpu>::Launch(s, n, out.data().dptr<int32_t>());
   }
-  Kernel<QuantizedSumKernel, cpu>::Launch(s, n * k, k, weight.data().dptr<int8_t>(), 
+  Kernel<QuantizedSumKernel, cpu>::Launch(s, n * k, k, weight.data().dptr<int8_t>(),
       out.data().dptr<int32_t>(), shift);
 
   Kernel<QuantizedBetaCKernel, cpu>::Launch(s, m * n, n, out.data().dptr<int32_t>());
@@ -145,7 +145,7 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
                      shiftdata.dptr_,
                      k,
                      oa,
-                     weight.data().dptr<int8_t>(), 
+                     weight.data().dptr<int8_t>(),
                      k,
                      ob,
                      beta,
diff --git a/src/operator/quantization/quantized_fully_connected.cc b/src/operator/quantization/quantized_fully_connected.cc
index 72ad19f11d2e..d596fd5bdcb0 100644
--- a/src/operator/quantization/quantized_fully_connected.cc
+++ b/src/operator/quantization/quantized_fully_connected.cc
@@ -80,10 +80,10 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs,
 }
 
 bool QuantizedFullyConnectedStorageType(const nnvm::NodeAttrs& attrs,
-    const int dev_mask,
-    DispatchMode* dispatch_mode,
-    std::vector<int> *in_attrs,
-    std::vector<int> *out_attrs) {
+                                        const int dev_mask,
+                                        DispatchMode* dispatch_mode,
+                                        std::vector<int> *in_attrs,
+                                        std::vector<int> *out_attrs) {
   *dispatch_mode = DispatchMode::kFCompute;
 #if MXNET_USE_MKLDNN == 1
   if (dev_mask == mshadow::cpu::kDevMask) {

From a3f16e4a3f288204649a316cfb992b969ba5b190 Mon Sep 17 00:00:00 2001
From: "Li, Hao H"
Date: Fri, 19 Oct 2018 16:17:26 +0800
Subject: [PATCH 10/14] fix typo bug

---
 .../mkldnn/mkldnn_quantized_fully_connected.cc | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
index f91dcf0d2010..1bef026ea207 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
@@ -26,8 +26,8 @@
 namespace mxnet {
 namespace op {
 
-namespace quantilizedfc {
-enum QuantilizedfcOpResource {kTempSpace};
+namespace quantized_fullc {
+enum QuantizedFullyConnectedOpResource {kTempSpace};
 }
 
@@ -54,7 +54,7 @@ struct QuantizedSumInitKernelWithBias {
       out[i] = bias[i] * float_for_one_bias_quant /
                float_for_one_out_quant;
     } else {
-      LOG(INFO) << "WARNING: QuantizedBiasAddKernel float_for_one_out_quant is 0 !";
+      LOG(INFO) << "WARNING: QuantizedSumInitKernelWithBias float_for_one_out_quant is 0 !";
       out[i] = 0;
     }
   }
@@ -112,7 +112,7 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
   int shift = 128;
   Tensor<cpu, 1, uint8_t> shiftdata =
-    ctx.requested[quantilizedfc::kTempSpace].get_space_typed<cpu, 1, uint8_t>(
+    ctx.requested[quantized_fullc::kTempSpace].get_space_typed<cpu, 1, uint8_t>(
       Shape1(m * k), s);
   Kernel<QuantizedShiftKernel, cpu>::Launch(s, m * k, data.data().dptr<int8_t>(),
       shiftdata.dptr_, shift);

From ae6c94039831e76dcba9874af44f81ebbae20b7c Mon Sep 17 00:00:00 2001
From: "Li, Hao H"
Date: Mon, 22 Oct 2018 14:59:29 +0800
Subject: [PATCH 11/14] optimize omp for sum and copyoffset

---
 .../mkldnn_quantized_fully_connected.cc       | 64 ++++++++-----------
 1 file changed, 25 insertions(+), 39 deletions(-)

diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
index 1bef026ea207..a1b11c6f60e0 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
@@ -26,16 +26,10 @@
 namespace mxnet {
 namespace op {
 
-namespace quantized_fullc {
-enum QuantizedFullyConnectedOpResource {kTempSpace};
+namespace quantilizedfc {
+enum QuantilizedfcOpResource {kTempSpace};
 }
 
-struct QuantizedShiftKernel {
-  MSHADOW_XINLINE static void Map(int i, int8_t *in, uint8_t *out, int shift) {
-    out[i] = in[i] + shift;
-  }
-};
-
 struct QuantizedSumInitKernelWithBias {
   // init sum data with bias for matrix b (n)
   MSHADOW_XINLINE static void Map(int i, int32_t *out,
@@ -48,12 +42,11 @@ struct QuantizedSumInitKernelWithBias {
       out[i] = bias[i] * float_for_one_bias_quant /
                float_for_one_out_quant;
     } else {
-      LOG(INFO) << "WARNING: QuantizedSumInitKernelWithBias float_for_one_out_quant is 0 !";
+      LOG(INFO) << "WARNING: QuantizedBiasAddKernel float_for_one_out_quant is 0 !";
       out[i] = 0;
     }
   }
 };
 
-struct QuantizedSumInitKernel {
-  // init sum data for matrix b (n)
-  MSHADOW_XINLINE static void Map(int i, int32_t *out) {
-    out[i] = 0;
-  }
-};
-
-struct QuantizedSumKernel {
-  // get sum data(n) for matrix b (n * k)
-  MSHADOW_XINLINE static void Map(int i, size_t k, int8_t *in, int32_t *out, int shift) {
-    out[i / k] -= shift * in[i];
-  }
-};
-
-struct QuantizedBetaCKernel {
-  // prepare beta C (from n to m * n)
-  MSHADOW_XINLINE static void Map(int i, size_t n, int32_t *out) {
-    out[i] = out[i % n];
-  }
-};
 
 template<typename SrcType>
 void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
                                           const OpContext &ctx,
@@ -72,7 +66,10 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
   TShape dshape = data.shape();
   TShape wshape = weight.shape();
   TShape oshape = out.shape();
-
+  auto output_temp = out.data().dptr<int32_t>();
+  auto weight_temp = weight.data().dptr<int8_t>();
+  auto data_temp = data.data().dptr<int8_t>();
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
   const float alpha = 1.0f;
   const float beta = 1.0f;
   const CBLAS_OFFSET offsetc = CblasFixOffset;
@@ -88,13 +85,16 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
   // cblas_gemm_s8u8s32 requires the first matrix to be uint8
   // shift data from int8 (from -128 to 127) to uint8 (from 0 to 255)
   int shift = 128;
   Tensor<cpu, 1, uint8_t> shiftdata =
     ctx.requested[quantilizedfc::kTempSpace].get_space_typed<cpu, 1, uint8_t>(
       Shape1(m * k), s);
-  Kernel<QuantizedShiftKernel, cpu>::Launch(s, m * k, data.data().dptr<int8_t>(),
-      shiftdata.dptr_, shift);
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = 0; i < m * k; ++i) {
+    shiftdata.dptr_[i] = data_temp[i] + shift;
+  }
+
   Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1,
     out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(),
     in_data[num_inputs].data().dptr<float>(), in_data[num_inputs+1].data().dptr<float>(),
     in_data[num_inputs+2].data().dptr<float>(), in_data[num_inputs+3].data().dptr<float>());
@@ -106,20 +106,20 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
       out_data[2].data().dptr<float>(), in_data[7].data().dptr<float>(),
       in_data[8].data().dptr<float>());
   } else {
-    Kernel<QuantizedSumInitKernel, cpu>::Launch(s, n, out.data().dptr<int32_t>());
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < m * n; ++i) {
+      output_temp[i] = 0;
+    }
   }
-  Kernel<QuantizedSumKernel, cpu>::Launch(s, n * k, k, weight.data().dptr<int8_t>(),
-      out.data().dptr<int32_t>(), shift);
-
-  Kernel<QuantizedBetaCKernel, cpu>::Launch(s, m * n, n, out.data().dptr<int32_t>());
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = 0; i < n * k; ++i) {
+    output_temp[i / k] -= shift * weight_temp[i];
+  }
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = n; i < m * n; ++i) {
+    output_temp[i] = output_temp[i % n];
+  }
   cblas_gemm_s8u8s32(CblasRowMajor,
                      CblasNoTrans,
                      CblasTrans,

From 18d04bcd661c0b16595c73b9b74bd96f977a5f45 Mon Sep 17 00:00:00 2001
From: "Li, Hao H"
Date: Mon, 22 Oct 2018 15:58:56 +0800
Subject: [PATCH 12/14] optimize for sum

---
 .../mkldnn/mkldnn_quantized_fully_connected.cc | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
index a1b11c6f60e0..981e4e9038db 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
@@ -111,10 +111,11 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
       output_temp[i] = 0;
     }
   }
-
   #pragma omp parallel for num_threads(omp_threads)
-  for (int i = 0; i < n * k; ++i) {
-    output_temp[i / k] -= shift * weight_temp[i];
+  for (int i = 0; i < n; ++i) {
+    for (int j = 0; j < k; ++j) {
+      output_temp[i] -= shift * weight_temp[i * k + j];
+    }
   }
   #pragma omp parallel for num_threads(omp_threads)
   for (int i = n; i < m * n; ++i) {

From e9e49c12ef22ce1a8142782442dc94fd68baf1b6 Mon Sep 17 00:00:00 2001
From: "Li, Hao H"
Date: Tue, 23 Oct 2018 10:44:13 +0800
Subject: [PATCH 13/14] add preprocessor macro for s8u8 mklml check

---
 .../mkldnn/mkldnn_quantized_fully_connected.cc | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
index 981e4e9038db..4137d1c1b89f 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
@@ -53,13 +53,14 @@ struct QuantizedSumInitKernelWithBias {
     }
   }
 };
-
 template<typename SrcType>
 void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
                                           const OpContext &ctx,
                                           const std::vector<NDArray> &in_data,
                                           const std::vector<OpReqType> &req,
                                           const std::vector<NDArray> &out_data) {
+  #if MSHADOW_USE_MKL == 1
+  // s8u8s32 implementation
   const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
   using namespace mshadow;
   using namespace mxnet_op;
@@ -121,6 +122,7 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
   for (int i = n; i < m * n; ++i) {
     output_temp[i] = output_temp[i % n];
   }
+
   cblas_gemm_s8u8s32(CblasRowMajor,
                      CblasNoTrans,
                      CblasTrans,
@@ -139,6 +141,9 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
                      out.data().dptr<int32_t>(),
                      n,
                      &oc);
+  #else
+  LOG(FATAL) << "s8u8s32 is not supported by the BLAS library";
+  #endif
 }

From e8bf13fffaa0c73886af2a494b6fac52c5893b3d Mon Sep 17 00:00:00 2001
From: "Li, Hao H"
Date: Tue, 23 Oct 2018 14:26:29 +0800
Subject: [PATCH 14/14] fix typo bug

---
 .../quantization/mkldnn/mkldnn_quantized_fully_connected.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
index 4137d1c1b89f..f1b772c7e3b0 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
@@ -142,7 +142,7 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
                      n,
                      &oc);
   #else
-  LOG(FATAL) << "s8u8s32 is not supported by the BLAS library";
+  LOG(FATAL) << "s8u8s32 is only supported by MKL BLAS library";
   #endif
 }
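
A note on the arithmetic the series converges on: cblas_gemm_s8u8s32 takes its first
operand as uint8, so the int8 activations are shifted by +128, and the term
shift * rowsum(weight) is pre-subtracted into the int32 output (patches 06, 11, and 12)
so that the beta = 1 GEMM reproduces the exact int8 product. The standalone sketch below
checks that identity with plain loops; it is an illustration only, not part of the
patches (no MKL, MXNet, or scale/bias handling involved), and the names A (m x k data),
W (n x k weight), and C (m x n output, computed as A * W^T) are chosen here for
exposition.

  #include <cstdint>
  #include <cstdio>
  #include <cstdlib>
  #include <vector>

  int main() {
    const int m = 4, n = 3, k = 5, shift = 128;
    std::vector<int8_t> A(m * k), W(n * k);
    for (auto &v : A) v = static_cast<int8_t>(std::rand() % 256 - 128);
    for (auto &v : W) v = static_cast<int8_t>(std::rand() % 256 - 128);

    // Reference: C_ref = A * W^T accumulated directly in int32.
    std::vector<int32_t> C_ref(m * n, 0);
    for (int i = 0; i < m; ++i)
      for (int j = 0; j < n; ++j)
        for (int p = 0; p < k; ++p)
          C_ref[i * n + j] += static_cast<int32_t>(A[i * k + p]) * W[j * k + p];

    // Shifted path: Au8 = A + 128 (now representable as uint8), and
    // C = Au8 * W^T - shift * rowsum(W), mirroring the compensation loops
    // that initialize output_temp before the beta = 1 GEMM call.
    std::vector<uint8_t> Au8(m * k);
    for (int i = 0; i < m * k; ++i)
      Au8[i] = static_cast<uint8_t>(A[i] + shift);

    std::vector<int32_t> C(m * n, 0);
    for (int j = 0; j < n; ++j) {
      int32_t comp = 0;                  // comp = sum_p W[j][p]
      for (int p = 0; p < k; ++p) comp += W[j * k + p];
      for (int i = 0; i < m; ++i)        // broadcast -shift * comp over all rows,
        C[i * n + j] = -shift * comp;    // as the QuantizedBetaCKernel/omp copy does
    }
    for (int i = 0; i < m; ++i)          // "GEMM" with beta = 1: C += Au8 * W^T
      for (int j = 0; j < n; ++j)
        for (int p = 0; p < k; ++p)
          C[i * n + j] += static_cast<int32_t>(Au8[i * k + p]) * W[j * k + p];

    for (int i = 0; i < m * n; ++i)
      if (C[i] != C_ref[i]) { std::printf("mismatch at %d\n", i); return 1; }
    std::printf("shift-compensated product matches direct int8 product\n");
    return 0;
  }

The identity holds because (A + 128) * W^T = A * W^T + 128 * (1 * W^T), and the (i, j)
entry of 1 * W^T is exactly the row sum of W's j-th row. The same pre-initialized
compensation row is also where QuantizedSumInitKernelWithBias folds in the rescaled
int8 bias, since the beta = 1 GEMM accumulates on top of whatever C already holds.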