Skip to content

Commit

Permalink
[Mxnet-1397] Support symbolic api for requantize and dequantize (apac…
Browse files Browse the repository at this point in the history
…he#14749)

* Adding support for symbolic API for requantize and dequantize

* Adding name to contributors list

* Removing redundant code

* Addressing indentation and using current_context() instead of cpu()

* merge from master

* merge from master
  • Loading branch information
shoubhik authored and haohuw committed Jun 23, 2019
1 parent 17a3bc0 commit f773295
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 11 deletions.
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ List of Contributors
* [Zhennan Qin](https://github.com/ZhennanQin)
* [Zhiyuan Huang](https://github.com/huangzhiyuan)
* [Zak Jost](https://github.com/zjost)
* [Shoubhik Bhattacharya](https://github.com/shoubhik)
* [Zach Kimberg](https://github.com/zachgk)

Label Bot
Expand Down
4 changes: 4 additions & 0 deletions src/operator/quantization/dequantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ by keep zero centered for the quantized value:
.set_attr_parser(ParamParser<DequantizeParam>)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "min_range", "max_range"};
})
.set_attr<mxnet::FInferShape>("FInferShape", DequantizeShape)
.set_attr<nnvm::FInferType>("FInferType", DequantizeType)
.set_attr<FInferStorageType>("FInferStorageType", DequantizeStorageType)
Expand Down
4 changes: 4 additions & 0 deletions src/operator/quantization/requantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ inference accuracy.
.set_attr_parser(ParamParser<RequantizeParam>)
.set_num_inputs(3)
.set_num_outputs(3)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "min_range", "max_range"};
})
.set_attr<mxnet::FInferShape>("FInferShape", QuantizeShape)
.set_attr<nnvm::FInferType>("FInferType", RequantizeType)
.set_attr<FInferStorageType>("FInferStorageType", RequantizeStorageType)
Expand Down
82 changes: 71 additions & 11 deletions tests/python/quantization/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,45 @@ def test_quantize_float32_to_int8():

@with_seed()
def test_dequantize_int8_to_float32():

def get_test_data(real_range, qdata_np):
qdata = mx.nd.array(qdata_np, dtype=np.int8)
min_range = mx.nd.array([-real_range], dtype=np.float32)
max_range = mx.nd.array([real_range], dtype=np.float32)
return qdata, min_range, max_range

def baseline_dequantization(qdata, real_range, qdata_np):
quantized_range = 127.0
scale = real_range / quantized_range
data_np = qdata_np * scale
return data_np

def test_nd_array_dequantization(qdata, min_range, max_range, expected_result):
data = mx.nd.contrib.dequantize(qdata, min_range, max_range, out_type='float32')
assert data.dtype == np.float32
assert_almost_equal(data.asnumpy(), expected_result)

def test_symbolic_api_dequantization(qdata, min_range, max_range, expected_result):
sym_data = mx.sym.Variable('data')
sym_min_range = mx.sym.Variable('min_range')
sym_max_range = mx.sym.Variable('max_range')
dequant = mx.sym.contrib.dequantize(sym_data, sym_min_range,
sym_max_range, out_type='float32')
out = dequant.bind(ctx=mx.current_context(),
args={'data':qdata, 'min_range':min_range, 'max_range':max_range})
data = out.forward()[0]
assert data.dtype == np.float32
assert_almost_equal(data.asnumpy(), expected_result)

real_range = 402.3347
shape = rand_shape_nd(4)
qdata_np = np.random.uniform(low=-127, high=127, size=shape).astype(dtype=np.int8)
qdata = mx.nd.array(qdata_np, dtype=np.int8)
real_range = 402.3347
min_range = mx.nd.array([-real_range], dtype=np.float32)
max_range = mx.nd.array([real_range], dtype=np.float32)
data = mx.nd.contrib.dequantize(qdata, min_range, max_range, out_type='float32')
quantized_range = 127.0
scale = real_range / quantized_range
assert data.dtype == np.float32
data_np = qdata_np * scale
assert_almost_equal(data.asnumpy(), data_np)

qdata, min_range, max_range = get_test_data(real_range, qdata_np)
expected_result = baseline_dequantization(qdata, real_range, qdata_np)
# test nd array implementation.
test_nd_array_dequantization(qdata, min_range, max_range, expected_result)
# test symbolic api implementaion.
test_symbolic_api_dequantization(qdata, min_range, max_range, expected_result)

@with_seed()
def test_requantize_int32_to_int8():
Expand Down Expand Up @@ -124,7 +150,41 @@ def check_requantize(shape, min_calib_range=None, max_calib_range=None):
assert_almost_equal(qdata_int8.asnumpy(), qdata_int8_np, atol = 1)
assert_almost_equal(min_output.asnumpy(), np.array([min_output_np]))
assert_almost_equal(max_output.asnumpy(), np.array([max_output_np]))

def check_requantize_with_symbol(shape, min_calib_range=None, max_calib_range=None):
qdata = mx.nd.random.uniform(low=-1000.0, high=1000.0, shape=shape).astype('int32')
min_range = mx.nd.array([-1010.0])
max_range = mx.nd.array([1020.0])
sym_data = mx.sym.Variable('data')
sym_min_range = mx.sym.Variable('min_range')
sym_max_range = mx.sym.Variable('max_range')
if min_calib_range is None or max_calib_range is None:
requant = mx.sym.contrib.requantize(sym_data, sym_min_range, sym_max_range)
out = requant.bind(ctx=mx.current_context(),
args={'data':qdata, 'min_range':min_range,
'max_range':max_range})
qdata_int8, min_output, max_output = out.forward()
else:
requant = mx.sym.contrib.requantize(sym_data, sym_min_range, sym_max_range,
min_calib_range, max_calib_range)
out = requant.bind(ctx=mx.current_context(), args={'data':qdata, 'min_range':min_range,
'max_range':max_range})
qdata_int8, min_output, max_output = out.forward()

qdata_int8_np, min_output_np, max_output_np = requantize_baseline(qdata.asnumpy(), min_range.asscalar(),
max_range.asscalar(),
min_calib_range=min_calib_range,
max_calib_range=max_calib_range)
assert_almost_equal(qdata_int8.asnumpy(), qdata_int8_np)
assert_almost_equal(min_output.asnumpy(), np.array([min_output_np]))
assert_almost_equal(max_output.asnumpy(), np.array([max_output_np]))

# test with symbol API.
check_requantize_with_symbol((3, 4, 10, 10))
check_requantize_with_symbol((32, 3, 23, 23))
check_requantize_with_symbol((3, 4, 10, 10), min_calib_range=-1050.0, max_calib_range=1040.0)
check_requantize_with_symbol((32, 3, 23, 23), min_calib_range=-134.349, max_calib_range=523.43)
# Test with nd array API
check_requantize((3, 4, 10, 10))
check_requantize((32, 3, 23, 23))
check_requantize((3, 4, 10, 10), min_calib_range=-1050.0, max_calib_range=1040.0)
Expand Down

0 comments on commit f773295

Please sign in to comment.