Skip to content

Commit

Permalink
support quantilized fc in cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
Li, Hao H committed Oct 16, 2018
1 parent b89a36d commit 1c20a87
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 12 deletions.
134 changes: 134 additions & 0 deletions src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#if MXNET_USE_MKLDNN == 1

#include "../../nn/mkldnn/mkldnn_base-inl.h"
#include "../quantization_utils.h"
#include "../../nn/fully_connected-inl.h"

namespace mxnet {
namespace op {

// value + bias_value * (range1 / limit_range1) * (limit_range2 / range2)
struct QuantizedBiasAddKernel {

This comment has been minimized.

MSHADOW_XINLINE static void Map(int i, size_t k, int32_t *out,
const int8_t *bias, const float *min_out,
const float *max_out, const float *min_bias,
const float *max_bias) {
typedef int32_t T1;
typedef int8_t T2;
using mshadow::red::limits::MinValue;
using mshadow::red::limits::MaxValue;
float float_for_one_out_quant =
MaxAbs(*min_out, *max_out) / static_cast<double>(MaxValue<T1>());
float float_for_one_bias_quant =
MaxAbs(*min_bias, *max_bias) / static_cast<double>(MaxValue<T2>());
if (float_for_one_out_quant != 0) {
out[i] = (out[i] * float_for_one_out_quant +
bias[i%k] * float_for_one_bias_quant) /
float_for_one_out_quant;
} else {
LOG(INFO) << "WARNING: QuantizedBiasAddKernel float_for_one_out_quant is 0 !";
}
}
};

template<typename SrcType, typename DstType, typename CmpType>

This comment has been minimized.

Copy link
@ciyongch

ciyongch Oct 16, 2018

Looks like DstType and CmpType is useless here.

void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,

This comment has been minimized.

Copy link
@TaoLv

TaoLv Oct 16, 2018

fix indent.

This comment has been minimized.

Copy link
@lihaofd

lihaofd Oct 16, 2018

Owner

fixed

const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data) {
const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
using namespace mshadow;
using namespace mxnet_op;
size_t num_inputs = param.no_bias ? 2 : 3;
CHECK_EQ(in_data.size(), num_inputs * 3);
CHECK_EQ(out_data.size(), 3U);
const NDArray& data = in_data[0];
const NDArray& weight = in_data[1];
const NDArray& out = out_data[0];
TShape dshape = data.shape();
TShape wshape = weight.shape();
TShape oshape = out.shape();

CHECK(in_data[0].dtype() == mshadow::kInt8
&& in_data[1].dtype() == mshadow::kInt8)
<< "mkldnn_quantized_FullyConnected op only supports int8 as input type";

This comment has been minimized.

Copy link
@TaoLv

TaoLv Oct 16, 2018

why?


const float alpha = 1.0f;
const float beta = 0.0f;
const CBLAS_OFFSET offsetc = CblasFixOffset;
const MKL_INT8 oa = -128;
const MKL_INT8 ob = 0;
MKL_INT32 oc = 0;
const int m = dshape[0], n = wshape[0], k = dshape.ProdShape(1, dshape.ndim());
const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
uint8_t* pDataNewRange = reinterpret_cast<uint8_t*>(malloc(m*k*sizeof(uint8_t)));

This comment has been minimized.

Copy link
@TaoLv

TaoLv Oct 16, 2018

register resource for this buffer to avoid malloc/delete in each run.

This comment has been minimized.

Copy link
@lihaofd

lihaofd Oct 16, 2018

Owner

fixed


#pragma omp parallel for num_threads(omp_threads)

This comment has been minimized.

Copy link
@ciyongch

ciyongch Oct 16, 2018

a comment here for why shift the data is better.

for (int i = 0; i < m * k; i++) {
pDataNewRange[i] = data.data().dptr<int8_t>()[i] + 128;

This comment has been minimized.

Copy link
@TaoLv

TaoLv Oct 16, 2018

what's 128? looks like a magic number.

}

cblas_gemm_s8u8s32(CblasRowMajor,
CblasNoTrans,
CblasTrans,
offsetc,
m,
n,
k,
alpha,
pDataNewRange,
k,
oa,
weight.data().dptr<int8_t>(),
k,
ob,
beta,
out.data().dptr<int32_t>(),
n,
&oc);

free(pDataNewRange);
Stream<cpu> *s = ctx.get_stream<cpu>();
Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1,
out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(),
in_data[num_inputs].data().dptr<float>(), in_data[num_inputs+1].data().dptr<float>(),
in_data[num_inputs+2].data().dptr<float>(), in_data[num_inputs+3].data().dptr<float>());

if (!param.no_bias) {

This comment has been minimized.

Copy link
@TaoLv

TaoLv Oct 16, 2018

Any possibility to put bias into gemm call?

const NDArray& bias = in_data[2];
Kernel<QuantizedBiasAddKernel, cpu>::Launch(s, out.shape().Size(),

This comment has been minimized.

Copy link
@ciyongch

ciyongch Oct 16, 2018

Do you measure the perf with and without Kernel Launch API implementation?

n, out.data().dptr<int32_t>(), bias.data().dptr<int8_t>(),
out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(),
in_data[7].data().dptr<float>(), in_data[8].data().dptr<float>());
}
}

NNVM_REGISTER_OP(_contrib_quantized_fully_connected)
.set_attr<FComputeEx>("FComputeEx<cpu>",
MKLDNNQuantizedFullyConnectedForward<int8_t, int32_t, int32_t>);


} // namespace op
} // namespace mxnet
#endif

17 changes: 17 additions & 0 deletions src/operator/quantization/quantized_fully_connected.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,22 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs,
return true;
}

bool QuantizedFullyConnectedStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
*dispatch_mode = DispatchMode::kFCompute;
#if MXNET_USE_MKLDNN == 1
if (dev_mask == mshadow::cpu::kDevMask) {
*dispatch_mode = DispatchMode::kFComputeEx;
}
#endif
for (size_t i = 0; i < out_attrs->size(); i++)
(*out_attrs)[i] = kDefaultStorage;
return true;
}

NNVM_REGISTER_OP(_contrib_quantized_fully_connected)
.describe(R"code(Fully Connected operator for input, weight and bias data type of int8,
and accumulates in type int32 for the output. For each argument, two more arguments of type
Expand Down Expand Up @@ -112,6 +128,7 @@ and max thresholds representing the threholds for quantizing the float32 output
})
.set_attr<nnvm::FInferShape>("FInferShape", QuantizedFullyConnectedShape)
.set_attr<nnvm::FInferType>("FInferType", QuantizedFullyConnectedType)
.set_attr<FInferStorageType>("FInferStorageType", QuantizedFullyConnectedStorageType)
.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
.add_argument("data", "NDArray-or-Symbol", "Input data.")
.add_argument("weight", "NDArray-or-Symbol", "weight.")
Expand Down
25 changes: 13 additions & 12 deletions tests/python/quantization/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,10 +269,7 @@ def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_p
@with_seed()
def test_quantized_fc():
def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
if mx.current_context().device_type != 'gpu':
print('skipped testing quantized_fc on cpu since it is not supported yet')
return
elif qdtype == 'uint8' and is_test_for_gpu():
if qdtype == 'uint8' and is_test_for_gpu():

This comment has been minimized.

Copy link
@ciyongch

ciyongch Oct 16, 2018

do we support uint8?

print('skipped testing quantized_fc for gpu uint8 since it is not supported yet')
return

Expand All @@ -283,17 +280,17 @@ def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
fc_fp32_exe = fc_fp32.simple_bind(ctx=mx.current_context(), grad_req='null')
if qdtype == 'uint8':
data_low = 0.0
data_high = 127.0
data_high = 63.0
else:
data_low = -127.0
data_high = 127.0
data_low = -63.0
data_high = 63.0
fc_fp32_exe.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
shape=data_shape).astype('int32')
fc_fp32_exe.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
shape=arg_shapes[1]).astype('int32')
shape=data_shape).astype('int8')
fc_fp32_exe.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
shape=arg_shapes[1]).astype('int8')
if not no_bias:
fc_fp32_exe.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
shape=arg_shapes[2]).astype('int32')
fc_fp32_exe.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
shape=arg_shapes[2]).astype('int8')
output = fc_fp32_exe.forward()[0]

qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype='int8')
Expand Down Expand Up @@ -335,6 +332,10 @@ def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
check_quantized_fc((32, 111, 2, 2), 100, True, qdtype)
check_quantized_fc((32, 512, 2, 2), 100, False, qdtype)
check_quantized_fc((32, 111, 2, 2), 100, False, qdtype)
check_quantized_fc((256, 2048, 2, 2), 800, False, qdtype)
check_quantized_fc((256, 111, 2, 2), 800, False, qdtype)
check_quantized_fc((256, 2048, 2, 2), 800, True, qdtype)
check_quantized_fc((256, 111, 2, 2), 800, True, qdtype)

@with_seed()
def test_quantized_flatten():
Expand Down

0 comments on commit 1c20a87

Please sign in to comment.