diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
new file mode 100644
index 000000000000..f1b772c7e3b0
--- /dev/null
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#if MXNET_USE_MKLDNN == 1
+
+#include "../../nn/mkldnn/mkldnn_base-inl.h"
+#include "../quantization_utils.h"
+#include "../../nn/fully_connected-inl.h"
+
+namespace mxnet {
+namespace op {
+
+namespace quantilizedfc {
+enum QuantilizedfcOpResource {kTempSpace};
+}
+
+struct QuantizedSumInitKernelWithBias {
+  // Initialize the int32 output with the re-scaled bias, one value per output channel (n).
+  MSHADOW_XINLINE static void Map(int i, int32_t *out,
+                                  const int8_t *bias, const float *min_out,
+                                  const float *max_out, const float *min_bias,
+                                  const float *max_bias) {
+    typedef int32_t T1;
+    typedef int8_t T2;
+    using mshadow::red::limits::MinValue;
+    using mshadow::red::limits::MaxValue;
+    float float_for_one_out_quant =
+        MaxAbs(*min_out, *max_out) / static_cast<float>(MaxValue<T1>());
+    float float_for_one_bias_quant =
+        MaxAbs(*min_bias, *max_bias) / static_cast<float>(MaxValue<T2>());
+    if (float_for_one_out_quant != 0) {
+      out[i] = bias[i] * float_for_one_bias_quant /
+          float_for_one_out_quant;
+    } else {
+      LOG(INFO) << "WARNING: QuantizedSumInitKernelWithBias float_for_one_out_quant is 0!";
+      out[i] = 0;
+    }
+  }
+};
+template<typename SrcType>
+void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
+                                          const OpContext &ctx,
+                                          const std::vector<NDArray> &in_data,
+                                          const std::vector<OpReqType> &req,
+                                          const std::vector<NDArray> &out_data) {
+  #if MSHADOW_USE_MKL == 1
+  // s8u8s32 implementation
+  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
+  using namespace mshadow;
+  using namespace mxnet_op;
+  size_t num_inputs = param.no_bias ? 2 : 3;
+  CHECK_EQ(in_data.size(), num_inputs * 3);
+  CHECK_EQ(out_data.size(), 3U);
+  const NDArray& data = in_data[0];
+  const NDArray& weight = in_data[1];
+  const NDArray& out = out_data[0];
+  TShape dshape = data.shape();
+  TShape wshape = weight.shape();
+  TShape oshape = out.shape();
+  auto output_temp = out.data().dptr<int32_t>();
+  auto weight_temp = weight.data().dptr<int8_t>();
+  auto data_temp = data.data().dptr<SrcType>();
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  const float alpha = 1.0f;
+  const float beta = 1.0f;
+  const CBLAS_OFFSET offsetc = CblasFixOffset;
+  const MKL_INT8 oa = 0;
+  const MKL_INT8 ob = 0;
+  MKL_INT32 oc = 0;
+  const int m = dshape[0], n = wshape[0], k = dshape.ProdShape(1, dshape.ndim());
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  // cblas_gemm_s8u8s32 requires the first matrix to be uint8, so shift the data
+  // from int8 (-128 to 127) to uint8 (0 to 255).
+  int shift = 128;
+  Tensor<cpu, 1, uint8_t> shiftdata =
+      ctx.requested[quantilizedfc::kTempSpace].get_space_typed<cpu, 1, uint8_t>(
+      Shape1(m * k), s);
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = 0; i < m * k; ++i) {
+    shiftdata.dptr_[i] = data_temp[i] + shift;
+  }
+
+  // Derive the output min/max range from the data and weight ranges.
+  Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1,
+      out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(),
+      in_data[num_inputs].data().dptr<float>(), in_data[num_inputs+1].data().dptr<float>(),
+      in_data[num_inputs+2].data().dptr<float>(), in_data[num_inputs+3].data().dptr<float>());
+  if (!param.no_bias) {
+    const NDArray& bias = in_data[2];
+    Kernel<QuantizedSumInitKernelWithBias, cpu>::Launch(s, n, out.data().dptr<int32_t>(),
+        bias.data().dptr<int8_t>(), out_data[1].data().dptr<float>(),
+        out_data[2].data().dptr<float>(), in_data[7].data().dptr<float>(),
+        in_data[8].data().dptr<float>());
+  } else {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < m * n; ++i) {
+      output_temp[i] = 0;
+    }
+  }
+  // Compensate for the +128 shift: subtract shift * sum_k(weight[i][k]) from
+  // each output column of the first row.
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = 0; i < n; ++i) {
+    for (int j = 0; j < k; ++j) {
+      output_temp[i] -= shift * weight_temp[i * k + j];
+    }
+  }
+  // Broadcast the per-column offsets (and bias) from the first row to all rows.
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = n; i < m * n; ++i) {
+    output_temp[i] = output_temp[i % n];
+  }
+
+  // out = shiftdata * weight^T + out (beta == 1 keeps the bias/offset initialization).
+  cblas_gemm_s8u8s32(CblasRowMajor,
+                     CblasNoTrans,
+                     CblasTrans,
+                     offsetc,
+                     m,
+                     n,
+                     k,
+                     alpha,
+                     shiftdata.dptr_,
+                     k,
+                     oa,
+                     weight.data().dptr<int8_t>(),
+                     k,
+                     ob,
+                     beta,
+                     out.data().dptr<int32_t>(),
+                     n,
+                     &oc);
+  #else
+  LOG(FATAL) << "s8u8s32 is only supported by MKL BLAS library";
+  #endif
+}
+
+NNVM_REGISTER_OP(_contrib_quantized_fully_connected)
+.set_attr<FComputeEx>("FComputeEx<cpu>",
+    MKLDNNQuantizedFullyConnectedForward<int8_t>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  });
+
+}  // namespace op
+}  // namespace mxnet
+#endif
diff --git a/src/operator/quantization/quantized_fully_connected.cc b/src/operator/quantization/quantized_fully_connected.cc
index e334fe7ec9b2..d596fd5bdcb0 100644
--- a/src/operator/quantization/quantized_fully_connected.cc
+++ b/src/operator/quantization/quantized_fully_connected.cc
@@ -79,6 +79,22 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+bool QuantizedFullyConnectedStorageType(const nnvm::NodeAttrs& attrs,
+                                        const int dev_mask,
+                                        DispatchMode* dispatch_mode,
+                                        std::vector<int> *in_attrs,
+                                        std::vector<int> *out_attrs) {
+  *dispatch_mode = DispatchMode::kFCompute;
+#if MXNET_USE_MKLDNN == 1
+  if (dev_mask == mshadow::cpu::kDevMask) {
+    *dispatch_mode = DispatchMode::kFComputeEx;
+  }
+#endif
+  for (size_t i = 0; i < out_attrs->size(); i++)
+    (*out_attrs)[i] = kDefaultStorage;
+  return true;
+}
+
 NNVM_REGISTER_OP(_contrib_quantized_fully_connected)
 .describe(R"code(Fully Connected operator for input, weight and bias data type of int8,
 and accumulates in type int32 for the output. For each argument, two more arguments of type
@@ -112,6 +128,7 @@ and max thresholds representing the threholds for quantizing the float32 output
   })
 .set_attr<nnvm::FInferShape>("FInferShape", QuantizedFullyConnectedShape)
 .set_attr<nnvm::FInferType>("FInferType", QuantizedFullyConnectedType)
+.set_attr<FInferStorageType>("FInferStorageType", QuantizedFullyConnectedStorageType)
 .set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
 .add_argument("data", "NDArray-or-Symbol", "Input data.")
 .add_argument("weight", "NDArray-or-Symbol", "weight.")
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index 5ae2c6c398e9..d85cf2c4eeed 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -269,10 +269,7 @@ def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_p
 @with_seed()
 def test_quantized_fc():
     def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
-        if mx.current_context().device_type != 'gpu':
-            print('skipped testing quantized_fc on cpu since it is not supported yet')
-            return
-        elif qdtype == 'uint8' and is_test_for_gpu():
+        if qdtype == 'uint8' and is_test_for_gpu():
             print('skipped testing quantized_fc for gpu uint8 since it is not supported yet')
             return
 
@@ -283,17 +280,17 @@ def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
         fc_fp32_exe = fc_fp32.simple_bind(ctx=mx.current_context(), grad_req='null')
         if qdtype == 'uint8':
             data_low = 0.0
-            data_high = 127.0
+            data_high = 63.0
         else:
-            data_low = -127.0
-            data_high = 127.0
+            data_low = -63.0
+            data_high = 63.0
         fc_fp32_exe.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
-                                                                     shape=data_shape).astype('int32')
-        fc_fp32_exe.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
-                                                                     shape=arg_shapes[1]).astype('int32')
+                                                                     shape=data_shape).astype('int8')
+        fc_fp32_exe.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
+                                                                     shape=arg_shapes[1]).astype('int8')
         if not no_bias:
-            fc_fp32_exe.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
-                                                                         shape=arg_shapes[2]).astype('int32')
+            fc_fp32_exe.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
+                                                                         shape=arg_shapes[2]).astype('int8')
         output = fc_fp32_exe.forward()[0]
 
         qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype='int8')
@@ -335,6 +332,10 @@ def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
         check_quantized_fc((32, 111, 2, 2), 100, True, qdtype)
         check_quantized_fc((32, 512, 2, 2), 100, False, qdtype)
         check_quantized_fc((32, 111, 2, 2), 100, False, qdtype)
+        check_quantized_fc((256, 2048, 2, 2), 800, False, qdtype)
+        check_quantized_fc((256, 111, 2, 2), 800, False, qdtype)
+        check_quantized_fc((256, 2048, 2, 2), 800, True, qdtype)
+        check_quantized_fc((256, 111, 2, 2), 800, True, qdtype)
 
 @with_seed()
 def test_quantized_flatten():
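
Reviewer note (not part of the patch): the NumPy sketch below, with made-up toy values, illustrates why the kernel above shifts the int8 activations by +128 for cblas_gemm_s8u8s32 and then subtracts shift * sum_k(weight[n][k]) from every output row before the GEMM. The two steps cancel exactly, so the result equals the plain signed int32 GEMM the operator is expected to compute.

import numpy as np

# Toy sizes; shift mirrors the kernel's int8 -> uint8 shift of 128.
m, n, k, shift = 4, 3, 5, 128
rng = np.random.default_rng(0)
data = rng.integers(-63, 64, size=(m, k)).astype(np.int32)    # int8 activations
weight = rng.integers(-63, 64, size=(n, k)).astype(np.int32)  # int8 weights

# Reference: plain int32 accumulation of int8 data times int8 weight^T.
ref = data @ weight.T

# s8u8s32-style computation: uint8 (data + shift) times int8 weight^T,
# then subtract the per-column offset, mirroring the OpenMP loops in the kernel.
shifted = data + shift                       # now in the uint8 range 65..191
offset = shift * weight.sum(axis=1)          # one offset per output column
out = shifted @ weight.T - offset[np.newaxis, :]

assert np.array_equal(out, ref)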
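
Similarly, a minimal sketch (hypothetical calibration ranges, illustrative only) of the bias seeding done by QuantizedSumInitKernelWithBias: the int8 bias is converted to its real value using the bias quantization step and re-expressed in the int32 output scale, so accumulating the GEMM on top of the seeded output effectively adds the bias.

import numpy as np

int8_max, int32_max = 127.0, 2147483647.0
min_bias, max_bias = -0.5, 0.5      # hypothetical fp32 bias range
min_out, max_out = -200.0, 200.0    # hypothetical range of the int32 output

bias_q = np.array([-100, 0, 50], dtype=np.int8)  # quantized bias, one per output channel

# Real value represented by one quantized step of the bias and of the output.
bias_scale = max(abs(min_bias), abs(max_bias)) / int8_max
out_scale = max(abs(min_out), abs(max_out)) / int32_max

# The kernel seeds the int32 output with the bias re-expressed in the output scale.
out_init = (bias_q * bias_scale / out_scale).astype(np.int32)

# Dequantizing the seed recovers (approximately) the real-valued bias.
assert np.allclose(out_init * out_scale, bias_q * bias_scale, rtol=1e-6)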