From eb9608b87e3d32a7de2f0f9711ee808b3fcec75b Mon Sep 17 00:00:00 2001
From: Matthew Brookhart
Date: Fri, 10 Dec 2021 09:09:02 -0700
Subject: [PATCH] Don't requantize if bias or quantize scales are approximately
 equal (#9676)

---
 .../transform/fake_quantization_to_integer.py | 20 +++++++++++++++----
 .../test_pass_fake_quantization_to_integer.py |  3 ++-
 2 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py
index 1adde9a4a4305..71dc9d9be99ea 100644
--- a/python/tvm/relay/transform/fake_quantization_to_integer.py
+++ b/python/tvm/relay/transform/fake_quantization_to_integer.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 """Relay functions for rewriting fake quantized ops."""
+import numpy as np
+
 import tvm
 from tvm import relay
 from tvm.ir import TensorAffineType, TupleAffineType
@@ -34,6 +36,16 @@ def infer_shape(expr):
     return relay.transform.InferType()(tvm.IRModule.from_expr(expr))["main"].body.checked_type.shape
 
 
+def approx_equal(x, y):
+    x = fold_constant(x)
+    y = fold_constant(y)
+    if isinstance(x, relay.Constant) and isinstance(y, relay.Constant):
+        equal = np.allclose(x.data.asnumpy(), y.data.asnumpy())
+    else:
+        equal = tvm.ir.structural_equal(x, y)
+    return equal
+
+
 @register_fake_quantization_to_integer("qnn.dequantize")
 def dequantize(expr, type_map):
     """Remove dequantize op"""
@@ -50,8 +62,8 @@ def quantize(expr, type_map):
     in_scale = fold_constant(t.scale)
     in_zero_point = fold_constant(t.zero_point)
     if not (
-        tvm.ir.structural_equal(in_scale, expr.args[1])
-        and tvm.ir.structural_equal(in_zero_point, expr.args[2])
+        approx_equal(in_scale, expr.args[1])
+        and approx_equal(in_zero_point, expr.args[2])
         and tvm.ir.structural_equal(t.dtype, expr.attrs.out_dtype)
     ):
         out = relay.qnn.op.requantize(
@@ -121,8 +133,8 @@ def bias_add(expr, type_map):
     in_scale = fold_constant(x_t.scale)
     in_zero_point = fold_constant(x_t.zero_point)
     if not (
-        tvm.ir.structural_equal(x_t.scale, b_t.scale)
-        and tvm.ir.structural_equal(x_t.zero_point, b_t.zero_point)
+        approx_equal(x_t.scale, b_t.scale)
+        and approx_equal(x_t.zero_point, b_t.zero_point)
         and tvm.ir.structural_equal(x_t.dtype, b_t.dtype)
     ):
         b = relay.qnn.op.requantize(
diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py
index c49d837ed9201..2889359138dff 100644
--- a/tests/python/relay/test_pass_fake_quantization_to_integer.py
+++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py
@@ -200,6 +200,7 @@ def test_fake_transpose_quantize_conv_bias_add_per_channel():
     one = relay.const(1.0)
     zero = relay.const(0)
     w_scale = (np.random.random([16]).astype("float32") - 0.5) / 10 + 0.5
+    noise = (np.random.random([16]).astype("float32") - 0.5) * 1e-15
     w_zp = relay.const([0] * 16)
     x = relay.qnn.op.dequantize(x, relay.const(2.0), zero)
 
@@ -208,7 +209,7 @@ def test_fake_transpose_quantize_conv_bias_add_per_channel():
         x, relay.qnn.op.dequantize(w, relay.const(w_scale), w_zp, axis=0), kernel_size=[5, 5]
     )
     op = relay.op.nn.bias_add(
-        op, relay.qnn.op.dequantize(bias, relay.const(2.0 * w_scale), w_zp, axis=0)
+        op, relay.qnn.op.dequantize(bias, relay.const(2.0 * w_scale + noise), w_zp, axis=0)
     )
     op = relay.qnn.op.quantize(op, one, zero)
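
---

Why the bitwise check was too strict: tvm.ir.structural_equal compares constant
tensors bit-for-bit, so two scale constants that agree numerically but were
computed along different paths (or, as in the updated test, carry deliberate
1e-15 noise) fail the check and force a needless qnn.requantize. A minimal
NumPy-only sketch of the comparison this patch relaxes; the values here are
illustrative, not taken from the test:

    import numpy as np

    scale = np.random.random(16) / 10 + 0.45       # per-channel scales near 0.5 (float64)
    noise = (np.random.random(16) - 0.5) * 1e-15   # perturbation near float64 epsilon

    a = 2.0 * scale          # one way of computing a bias scale
    b = 2.0 * scale + noise  # the "same" scale with its low bits disturbed

    print(np.array_equal(a, b))  # almost always False: bit patterns differ
    print(np.allclose(a, b))     # True: equal within default rtol/atol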
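The new helper folds both sides to constants first and only then compares
values, falling back to structural equality for non-constant expressions. A
hedged usage sketch, assuming a TVM build that includes this patch
(approx_equal is a module-level helper in this file, not part of TVM's public
API):

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay.transform.fake_quantization_to_integer import approx_equal

    # Two float64 scalar constants that differ only in the low bits.
    a = relay.const(np.array(0.5))
    b = relay.const(np.array(0.5 + 1e-12))
    print(tvm.ir.structural_equal(a, b))  # False: constants compared bitwise
    print(approx_equal(a, b))             # True: np.allclose after folding

    # Expressions that fold to constants are folded before comparison.
    c = relay.const(2.0) * relay.const(0.25)
    print(approx_equal(c, relay.const(0.5)))  # True: c folds to const(0.5)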
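End to end, a quantize whose scale and zero point match its input's affine
type should now produce no qnn.requantize at all. A sketch of that path under
the same assumptions; the toy graph and shapes are illustrative, and whether
the pass fully collapses this trivial dequantize/quantize pair is an
assumption worth verifying (the approximate per-channel case is what the
updated test exercises):

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=[1, 16], dtype="int8")
    op = relay.qnn.op.dequantize(x, relay.const(np.float32(0.25)), relay.const(0))
    op = relay.qnn.op.quantize(
        op, relay.const(np.float32(0.25)), relay.const(0), out_dtype="int8"
    )

    mod = tvm.IRModule.from_expr(op)
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.FakeQuantizationToInteger()(mod)
    print(mod["main"])  # expected: identity over x, with no qnn.requantize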