Don't requantize if bias or quantize scales are approximately equal (#…
Matthew Brookhart authored Dec 10, 2021
1 parent aa99699 commit 510f7c6
Showing 2 changed files with 18 additions and 5 deletions.
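
Background for the change: the pass previously compared the incoming and requested quantization parameters with tvm.ir.structural_equal, which treats two float constants as different even when they differ only by rounding noise, so a redundant qnn.requantize was emitted. A minimal sketch of that distinction (not part of the commit; the array values are made up, and it assumes a TVM build with Relay available):

    import numpy as np
    import tvm
    from tvm import relay

    # Two per-channel scales that are "the same" up to float noise.
    scale_a = relay.const(np.array([0.5, 0.25], dtype="float32"))
    scale_b = relay.const(np.array([0.5 + 1e-6, 0.25], dtype="float32"))

    # Structural equality compares the constant data exactly, so it reports a mismatch...
    print(tvm.ir.structural_equal(scale_a, scale_b))                    # False
    # ...while a numerical comparison within np.allclose's default tolerance does not.
    print(np.allclose(scale_a.data.asnumpy(), scale_b.data.asnumpy()))  # True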
20 changes: 16 additions & 4 deletions python/tvm/relay/transform/fake_quantization_to_integer.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 """Relay functions for rewriting fake quantized ops."""
+import numpy as np
+
 import tvm
 from tvm import relay
 from tvm.ir import TensorAffineType, TupleAffineType
@@ -34,6 +36,16 @@ def infer_shape(expr):
     return relay.transform.InferType()(tvm.IRModule.from_expr(expr))["main"].body.checked_type.shape


+def approx_equal(x, y):
+    x = fold_constant(x)
+    y = fold_constant(y)
+    if isinstance(x, relay.Constant) and isinstance(y, relay.Constant):
+        equal = np.allclose(x.data.asnumpy(), y.data.asnumpy())
+    else:
+        equal = tvm.ir.structural_equal(x, y)
+    return equal
+
+
 @register_fake_quantization_to_integer("qnn.dequantize")
 def dequantize(expr, type_map):
     """Remove dequantize op"""
@@ -50,8 +62,8 @@ def quantize(expr, type_map):
     in_scale = fold_constant(t.scale)
     in_zero_point = fold_constant(t.zero_point)
     if not (
-        tvm.ir.structural_equal(in_scale, expr.args[1])
-        and tvm.ir.structural_equal(in_zero_point, expr.args[2])
+        approx_equal(in_scale, expr.args[1])
+        and approx_equal(in_zero_point, expr.args[2])
         and tvm.ir.structural_equal(t.dtype, expr.attrs.out_dtype)
     ):
         out = relay.qnn.op.requantize(
@@ -121,8 +133,8 @@ def bias_add(expr, type_map):
     in_scale = fold_constant(x_t.scale)
     in_zero_point = fold_constant(x_t.zero_point)
     if not (
-        tvm.ir.structural_equal(x_t.scale, b_t.scale)
-        and tvm.ir.structural_equal(x_t.zero_point, b_t.zero_point)
+        approx_equal(x_t.scale, b_t.scale)
+        and approx_equal(x_t.zero_point, b_t.zero_point)
         and tvm.ir.structural_equal(x_t.dtype, b_t.dtype)
     ):
         b = relay.qnn.op.requantize(
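
A quick way to exercise the new helper in isolation (a sketch, not part of the commit; it assumes the module import path matches the file above and that fold_constant leaves already-constant arguments untouched):

    import numpy as np
    from tvm import relay
    from tvm.relay.transform.fake_quantization_to_integer import approx_equal

    base = np.array([0.5, 0.25], dtype="float32")
    print(approx_equal(relay.const(base), relay.const(base + 1e-6)))  # True: within np.allclose tolerance
    print(approx_equal(relay.const(base), relay.const(2 * base)))     # False: genuinely different scales

When either side does not fold to a relay.Constant, the helper falls back to tvm.ir.structural_equal, so non-constant parameters behave exactly as before this change.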
3 changes: 2 additions & 1 deletion tests/python/relay/test_pass_fake_quantization_to_integer.py
@@ -200,6 +200,7 @@ def test_fake_transpose_quantize_conv_bias_add_per_channel():
     one = relay.const(1.0)
     zero = relay.const(0)
     w_scale = (np.random.random([16]).astype("float32") - 0.5) / 10 + 0.5
+    noise = (np.random.random([16]).astype("float32") - 0.5) * 1e-15
     w_zp = relay.const([0] * 16)

     x = relay.qnn.op.dequantize(x, relay.const(2.0), zero)
@@ -208,7 +209,7 @@
         x, relay.qnn.op.dequantize(w, relay.const(w_scale), w_zp, axis=0), kernel_size=[5, 5]
     )
     op = relay.op.nn.bias_add(
-        op, relay.qnn.op.dequantize(bias, relay.const(2.0 * w_scale), w_zp, axis=0)
+        op, relay.qnn.op.dequantize(bias, relay.const(2.0 * w_scale + noise), w_zp, axis=0)
     )
     op = relay.qnn.op.quantize(op, one, zero)
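
For reference, the perturbation added in the test is far below np.allclose's default tolerances (rtol=1e-05, atol=1e-08), so a bias scale of 2.0 * w_scale + noise compares as approximately equal to the unperturbed 2.0 * w_scale and no extra requantize is needed. A standalone numpy check of that tolerance, mirroring the values from the test (illustrative only):

    import numpy as np

    w_scale = (np.random.random([16]).astype("float32") - 0.5) / 10 + 0.5
    noise = (np.random.random([16]).astype("float32") - 0.5) * 1e-15

    assert np.allclose(2.0 * w_scale, 2.0 * w_scale + noise)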
