Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Mxnet-1397] Support symbolic api for requantize and dequantize #14749

Merged
merged 7 commits into from
Apr 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ List of Contributors
* [Zhennan Qin](https://github.com/ZhennanQin)
* [Zhiyuan Huang](https://github.com/huangzhiyuan)
* [Zak Jost](https://github.com/zjost)
* [Shoubhik Bhattacharya](https://github.com/shoubhik)
* [Zach Kimberg](https://github.com/zachgk)

Label Bot
Expand Down
4 changes: 4 additions & 0 deletions src/operator/quantization/dequantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ by keep zero centered for the quantized value:
.set_attr_parser(ParamParser<DequantizeParam>)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "min_range", "max_range"};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If these names will be exposed to front end users, I hope they can align with other quantization operators. In quantized convolution and quantized FC, I see they are min_data and max_data.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the names are documented well in most of the quantized ops I think it should be ok. Especially in quantized conv and FC there are too many quantized parameters, I think it is easier to understand the API with min_data and max_data

})
.set_attr<mxnet::FInferShape>("FInferShape", DequantizeShape)
.set_attr<nnvm::FInferType>("FInferType", DequantizeType)
.set_attr<FInferStorageType>("FInferStorageType", DequantizeStorageType)
Expand Down
4 changes: 4 additions & 0 deletions src/operator/quantization/requantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ inference accuracy.
.set_attr_parser(ParamParser<RequantizeParam>)
.set_num_inputs(3)
.set_num_outputs(3)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "min_range", "max_range"};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as above.

})
.set_attr<mxnet::FInferShape>("FInferShape", QuantizeShape)
.set_attr<nnvm::FInferType>("FInferType", RequantizeType)
.set_attr<FInferStorageType>("FInferStorageType", RequantizeStorageType)
Expand Down
82 changes: 71 additions & 11 deletions tests/python/quantization/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,45 @@ def test_quantize_float32_to_int8():

@with_seed()
def test_dequantize_int8_to_float32():
    """Check int8 -> float32 dequantization via the NDArray and symbolic APIs.

    Random int8 data is dequantized against the symmetric range
    [-real_range, real_range]; both API flavours must return float32 output
    matching a NumPy reference computed as ``qdata * real_range / 127``.
    """

    def get_test_data(real_range, qdata_np):
        """Wrap the NumPy int8 data and the +/- real_range bounds as NDArrays."""
        qdata = mx.nd.array(qdata_np, dtype=np.int8)
        min_range = mx.nd.array([-real_range], dtype=np.float32)
        max_range = mx.nd.array([real_range], dtype=np.float32)
        return qdata, min_range, max_range

    def baseline_dequantization(qdata, real_range, qdata_np):
        """NumPy reference result: scale every int8 value by real_range / 127.

        ``qdata`` is accepted to keep the call site uniform but is not used.
        """
        quantized_range = 127.0
        scale = real_range / quantized_range
        data_np = qdata_np * scale
        return data_np

    def test_nd_array_dequantization(qdata, min_range, max_range, expected_result):
        """Run the imperative (NDArray) operator and compare with the reference."""
        data = mx.nd.contrib.dequantize(qdata, min_range, max_range, out_type='float32')
        assert data.dtype == np.float32
        assert_almost_equal(data.asnumpy(), expected_result)

    def test_symbolic_api_dequantization(qdata, min_range, max_range, expected_result):
        """Build the operator symbolically, bind the inputs by name, and compare."""
        sym_data = mx.sym.Variable('data')
        sym_min_range = mx.sym.Variable('min_range')
        sym_max_range = mx.sym.Variable('max_range')
        dequant = mx.sym.contrib.dequantize(sym_data, sym_min_range,
                                            sym_max_range, out_type='float32')
        # The keys below are the names of the Variables created above.
        out = dequant.bind(ctx=mx.current_context(),
                           args={'data': qdata, 'min_range': min_range, 'max_range': max_range})
        data = out.forward()[0]
        assert data.dtype == np.float32
        assert_almost_equal(data.asnumpy(), expected_result)

    real_range = 402.3347
    shape = rand_shape_nd(4)
    qdata_np = np.random.uniform(low=-127, high=127, size=shape).astype(dtype=np.int8)

    qdata, min_range, max_range = get_test_data(real_range, qdata_np)
    expected_result = baseline_dequantization(qdata, real_range, qdata_np)
    # test nd array implementation.
    test_nd_array_dequantization(qdata, min_range, max_range, expected_result)
    # test symbolic api implementation.
    test_symbolic_api_dequantization(qdata, min_range, max_range, expected_result)

@with_seed()
def test_requantize_int32_to_int8():
Expand Down Expand Up @@ -124,7 +150,41 @@ def check_requantize(shape, min_calib_range=None, max_calib_range=None):
assert_almost_equal(qdata_int8.asnumpy(), qdata_int8_np, atol = 1)
assert_almost_equal(min_output.asnumpy(), np.array([min_output_np]))
assert_almost_equal(max_output.asnumpy(), np.array([max_output_np]))

def check_requantize_with_symbol(shape, min_calib_range=None, max_calib_range=None):
    """Check the symbolic requantize operator against the NumPy baseline.

    Parameters
    ----------
    shape : tuple of int
        Shape of the random int32 input tensor.
    min_calib_range, max_calib_range : float, optional
        Calibration bounds. They are forwarded to the operator only when both
        are given; otherwise the operator is built without them.
    """
    qdata = mx.nd.random.uniform(low=-1000.0, high=1000.0, shape=shape).astype('int32')
    min_range = mx.nd.array([-1010.0])
    max_range = mx.nd.array([1020.0])

    sym_data = mx.sym.Variable('data')
    sym_min_range = mx.sym.Variable('min_range')
    sym_max_range = mx.sym.Variable('max_range')
    if min_calib_range is None or max_calib_range is None:
        requant = mx.sym.contrib.requantize(sym_data, sym_min_range, sym_max_range)
    else:
        # Pass the scalar attributes by keyword so they can never be mistaken
        # for another positional operator parameter.
        requant = mx.sym.contrib.requantize(sym_data, sym_min_range, sym_max_range,
                                            min_calib_range=min_calib_range,
                                            max_calib_range=max_calib_range)

    # Binding and execution are identical for both operator variants.
    out = requant.bind(ctx=mx.current_context(),
                       args={'data': qdata, 'min_range': min_range,
                             'max_range': max_range})
    qdata_int8, min_output, max_output = out.forward()

    qdata_int8_np, min_output_np, max_output_np = requantize_baseline(
        qdata.asnumpy(), min_range.asscalar(), max_range.asscalar(),
        min_calib_range=min_calib_range,
        max_calib_range=max_calib_range)
    # Same tolerance as the NDArray variant of this check (atol = 1).
    assert_almost_equal(qdata_int8.asnumpy(), qdata_int8_np, atol=1)
    assert_almost_equal(min_output.asnumpy(), np.array([min_output_np]))
    assert_almost_equal(max_output.asnumpy(), np.array([max_output_np]))

# test with symbol API.
check_requantize_with_symbol((3, 4, 10, 10))
check_requantize_with_symbol((32, 3, 23, 23))
check_requantize_with_symbol((3, 4, 10, 10), min_calib_range=-1050.0, max_calib_range=1040.0)
check_requantize_with_symbol((32, 3, 23, 23), min_calib_range=-134.349, max_calib_range=523.43)
# Test with nd array API
check_requantize((3, 4, 10, 10))
check_requantize((32, 3, 23, 23))
check_requantize((3, 4, 10, 10), min_calib_range=-1050.0, max_calib_range=1040.0)
Expand Down