From 08a6ee520fbcec8cb3190ac14929f7e88694f053 Mon Sep 17 00:00:00 2001 From: Venkat Rasagna Komatireddy <89959097+rasagna-quic@users.noreply.github.com> Date: Sat, 16 Sep 2023 00:03:44 +0530 Subject: [PATCH] [Hexagon] F2qi avgpool bug fix (#15599) F2qi avgpool bug fix --- .../transform/fake_quantization_to_integer.py | 31 ++ python/tvm/topi/hexagon/qnn/avg_pool2d.py | 22 +- .../contrib/test_hexagon/infrastructure.py | 22 ++ .../test_hexagon/test_pass_fq2i_avg_pool2d.py | 290 ++++++++++++++++++ .../test_relay_simplify_conv_pat.py | 23 +- 5 files changed, 354 insertions(+), 34 deletions(-) create mode 100644 tests/python/contrib/test_hexagon/test_pass_fq2i_avg_pool2d.py diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 5e289b0c9380..b27fc3cba799 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -119,6 +119,37 @@ def avgpool2d(expr, type_map): t = type_map[arg] out_t = type_map[expr] + # dq > nn.avg_pool2d > q + # Use the same input quantization parameters for output if the pattern is not the above. + # Type_map is a map of graphs and their Tensoraffinetypes + # Find the current "nn.avg_pool2d" op after checking for the "qnn.quantize" op in the graph. + # Structure for .. dq > op > q will be q [op [dq .. + def check(y, expr): + if isinstance(y, type(expr)): + if y.op.name != "nn.avg_pool2d": + return True + # check if this is the expr avg_pool + if y.attrs != expr.attrs: + return True + return False + + for x in type_map.items(): + if isinstance(x[0], type(expr)): + if x[0].op.name == "qnn.quantize": + prev = x[0] + y = prev.args[0] + while check(y, expr): + prev = y + y = prev.args[0] + if ( + isinstance(y, type(expr)) + and y.op.name == "nn.avg_pool2d" + and y.attrs == expr.attrs + ): + if prev.op.name != "qnn.quantize": + out_t = t + break + out = relay.qnn.op.avg_pool2d( arg, t.scale, diff --git a/python/tvm/topi/hexagon/qnn/avg_pool2d.py b/python/tvm/topi/hexagon/qnn/avg_pool2d.py index 1370ad36e468..4e88f39b0552 100644 --- a/python/tvm/topi/hexagon/qnn/avg_pool2d.py +++ b/python/tvm/topi/hexagon/qnn/avg_pool2d.py @@ -39,6 +39,14 @@ def saturate(x: te.Tensor, dtype: str): return te.max(te.min_value(dtype), te.min(x, te.max_value(dtype))) +def get_temp_dtype(h, w, dtype): + temp_dtype = "int16" if h * w < 256 else "int32" + if dtype in ("uint8", "int8"): + return temp_dtype + else: + raise RuntimeError(f"Unsupported output dtype, {odtype}'") + + def qnn_avg_pool2d_NCHW( data: te.Tensor, kernel: list, @@ -59,12 +67,7 @@ def qnn_avg_pool2d_NCHW( rh = te.reduce_axis((0, kh), name="rh") rw = te.reduce_axis((0, kw), name="rw") - if odtype == "uint8": - temp_dtype = "uint16" - elif odtype == "int8": - temp_dtype = "int16" - else: - raise RuntimeError(f"Unsupported output dtype, {odtype}'") + temp_dtype = get_temp_dtype(kh, kw, odtype) sh, sw = stride dh, dw = dilation @@ -155,12 +158,7 @@ def qnn_avg_pool2d_NHWC( rh = te.reduce_axis((0, kh), name="rh") rw = te.reduce_axis((0, kw), name="rw") - if odtype == "uint8": - temp_dtype = "uint16" - elif odtype == "int8": - temp_dtype = "int16" - else: - raise RuntimeError(f"Unsupported output dtype, {odtype}'") + temp_dtype = get_temp_dtype(kh, kw, odtype) sh, sw = stride dh, dw = dilation diff --git a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py index 735b3f2b94b5..c728ae842cf1 100644 --- a/tests/python/contrib/test_hexagon/infrastructure.py +++ b/tests/python/contrib/test_hexagon/infrastructure.py @@ -21,6 +21,7 @@ import numpy import tvm from tvm import te +from tvm.relay.backend import Executor def ceildiv(o, d): @@ -112,6 +113,27 @@ def build_and_run(inputs, func, target: str, target_host: str, *args, **kwargs): return tensors[-1].asnumpy() +def build_module(relay_mod, target): + """builds a relay module for a specified target""" + params = {} + executor = Executor("aot", {"link-params": True}) + lowered = tvm.relay.build( + relay_mod, + tvm.target.Target(target, host=target), + executor=executor, + params=params, + ) + return lowered + + +def run_module(mod, inputs): + """invokes run function of specified module with inputs provided""" + mod.set_input(**inputs) + mod.run() + output = mod.get_output(0).numpy() + return output + + def get_conv2d_nhwc_shape(shape_nhwc, kernel_size, strides, padding, dilation, out_channels): assert len(shape_nhwc) == 4 kernel = [] diff --git a/tests/python/contrib/test_hexagon/test_pass_fq2i_avg_pool2d.py b/tests/python/contrib/test_hexagon/test_pass_fq2i_avg_pool2d.py new file mode 100644 index 000000000000..34f356a0158c --- /dev/null +++ b/tests/python/contrib/test_hexagon/test_pass_fq2i_avg_pool2d.py @@ -0,0 +1,290 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" Tests for avg_pool2d fake quantization to integer """ + +import numpy as np +import tvm +import tvm.testing +import tvm.topi.testing +from tvm import relay +from tvm.contrib.hexagon.session import Session +from tvm.contrib.hexagon.pytest_plugin import HEXAGON_AOT_LLVM_TARGET +from .infrastructure import quantize_np, build_module, run_module + + +def compare_graphs(expr, ref_expr): + """Compares the given graph with the expected graph""" + mod = tvm.IRModule.from_expr(expr) + mod = tvm.relay.transform.InferType()(mod) + mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod) + ref_mod = tvm.IRModule.from_expr(ref_expr) + ref_mod = tvm.relay.transform.InferType()(ref_mod) + assert tvm.ir.structural_equal(mod_int["main"], ref_mod["main"], map_free_vars=True) + + +def compare_fq_to_int(hexagon_session, expr, inputs): + """Compares the float module output with the integer module output""" + mod = tvm.IRModule.from_expr(expr) + mod = tvm.relay.transform.InferType()(mod) + mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod) + assert not tvm.ir.structural_equal(mod, mod_int) + + mod = build_module( + mod, tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET) + ) + mod_int = build_module( + mod_int, tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET) + ) + + hexagon_mod = hexagon_session.get_executor_from_factory(mod) + result = run_module(hexagon_mod, inputs) + + hexagon_mod = hexagon_session.get_executor_from_factory(mod_int) + result_int = run_module(hexagon_mod, inputs) + + tvm.testing.assert_allclose(result, result_int, rtol=1e-02, atol=1e-02) + + +@tvm.testing.requires_hexagon +def test_avgpool_conv2d(hexagon_session: Session): + """Test case with avg_pool2d followed by a conv2d""" + dtype = "int8" + shape_x = [1, 2, 9, 9] + shape_w = [1, 2, 3, 3] + kernel = [3, 3] + stride = [1, 1] + dilation = [1, 1] + inp = relay.var("input", shape=shape_x, dtype=dtype) + wgt = relay.var("weight", shape=shape_w, dtype=dtype) + + x_np = np.random.random(shape_x) + w_np = np.random.random(shape_w) + + fp_avg = tvm.topi.testing.poolnd_python( + x_np, + kernel, + stride, + dilation, + padding_before=[0, 0], + padding_after=[0, 0], + pool_type="avg", + ) + fp_output = tvm.topi.testing.conv2d_nchw_python( + fp_avg, + w_np, + [1, 1], + [0, 0], + ) + + # Computing quantization parameters + input_quant, input_scale, input_zero_point = quantize_np(x_np, dtype) + weight_quant, weight_scale, weight_zero_point = quantize_np(w_np, dtype) + _, output_scale, output_zero_point = quantize_np(fp_output, dtype) + + inp_zp = relay.const(input_zero_point) + inp_sc = relay.const(input_scale) + wgt_zp = relay.const(weight_zero_point) + wgt_sc = relay.const(weight_scale) + out_zp = relay.const(output_zero_point) + out_sc = relay.const(output_scale) + + # Tested expression. + op0 = relay.qnn.op.dequantize(inp, inp_sc, inp_zp) + op1 = relay.op.nn.avg_pool2d(op0, kernel) + op2 = relay.qnn.op.dequantize(wgt, wgt_sc, wgt_zp) + op3 = relay.op.nn.conv2d(op1, op2, kernel_size=kernel) + expr = relay.qnn.op.quantize(op3, out_sc, out_zp, out_dtype=dtype) + expr = relay.qnn.op.dequantize(expr, out_sc, out_zp) + args = {"input": input_quant, "weight": weight_quant} + + compare_fq_to_int(hexagon_session, expr, args) + + # Expected graph + op0 = relay.qnn.op.avg_pool2d( + inp, + input_scale=inp_sc, + input_zero_point=inp_zp, + output_scale=inp_sc, + output_zero_point=inp_zp, + pool_size=kernel, + strides=stride, + dilation=dilation, + padding=[0, 0, 0, 0], + layout="NCHW", + count_include_pad=False, + ) + op1 = relay.qnn.op.conv2d( + op0, + wgt, + input_scale=inp_sc, + input_zero_point=inp_zp, + kernel_scale=wgt_sc, + kernel_zero_point=wgt_zp, + kernel_size=kernel, + channels=None, + ) + op2 = relay.qnn.op.requantize( + op1, + input_scale=relay.const(input_scale * weight_scale), + input_zero_point=relay.const(0), + output_scale=out_sc, + output_zero_point=out_zp, + axis=1, + out_dtype="int8", + ) + ref_expr = relay.qnn.op.dequantize(op2, out_sc, out_zp) + compare_graphs(expr, ref_expr) + + +@tvm.testing.requires_hexagon +def test_avgpool_avgpool(hexagon_session: Session): + """Test case with avg_pool2d followed by an avg_pool2d""" + dtype = "uint8" + shape_x = [1, 2, 9, 9] + kernel = [3, 3] + stride = [1, 1] + dilation = [1, 1] + inp = relay.var("input", shape=shape_x, dtype=dtype) + x_np = np.random.random(shape_x) + + fp_avg = tvm.topi.testing.poolnd_python( + x_np, + kernel, + stride, + dilation, + padding_before=[0, 0], + padding_after=[0, 0], + pool_type="avg", + ) + fp_output = tvm.topi.testing.poolnd_python( + fp_avg, + kernel, + stride, + dilation, + padding_before=[0, 0], + padding_after=[0, 0], + pool_type="avg", + ) + + # Computing quantization parameters + input_quant, input_scale, input_zero_point = quantize_np(x_np, dtype) + _, output_scale, output_zero_point = quantize_np(fp_output, dtype) + + inp_zp = relay.const(input_zero_point) + inp_sc = relay.const(input_scale) + out_zp = relay.const(output_zero_point) + out_sc = relay.const(output_scale) + + # Tested expression. + op0 = relay.qnn.op.dequantize(inp, inp_sc, inp_zp) + op1 = relay.op.nn.avg_pool2d(op0, kernel) + op2 = relay.op.nn.avg_pool2d(op1, kernel) + expr = relay.qnn.op.quantize(op2, out_sc, out_zp, out_dtype=dtype) + expr = relay.qnn.op.dequantize(expr, out_sc, out_zp) + args = {"input": input_quant} + compare_fq_to_int(hexagon_session, expr, args) + + # Expected graph + op0 = relay.qnn.op.avg_pool2d( + inp, + input_scale=inp_sc, + input_zero_point=inp_zp, + output_scale=inp_sc, + output_zero_point=inp_zp, + pool_size=kernel, + strides=stride, + dilation=dilation, + padding=[0, 0, 0, 0], + layout="NCHW", + count_include_pad=False, + ) + op1 = relay.qnn.op.avg_pool2d( + op0, + input_scale=inp_sc, + input_zero_point=inp_zp, + output_scale=out_sc, + output_zero_point=out_zp, + pool_size=kernel, + strides=stride, + dilation=dilation, + padding=[0, 0, 0, 0], + layout="NCHW", + count_include_pad=False, + ) + ref_expr = relay.qnn.op.dequantize(op1, out_sc, out_zp) + compare_graphs(expr, ref_expr) + + +@tvm.testing.requires_hexagon +def test_avgpool(hexagon_session: Session): + """Test case of a regular avg_pool2d""" + dtype = "int8" + shape_x = [1, 2, 9, 9] + kernel = [3, 3] + stride = [1, 1] + dilation = [1, 1] + inp = relay.var("input", shape=shape_x, dtype=dtype) + x_np = np.random.random(shape_x) + + fp_output = tvm.topi.testing.poolnd_python( + x_np, + kernel, + stride, + dilation, + padding_before=[0, 0], + padding_after=[0, 0], + pool_type="avg", + ) + + # Computing quantization parameters + input_quant, input_scale, input_zero_point = quantize_np(x_np, dtype) + _, output_scale, output_zero_point = quantize_np(fp_output, dtype) + + inp_zp = relay.const(input_zero_point) + inp_sc = relay.const(input_scale) + out_zp = relay.const(output_zero_point) + out_sc = relay.const(output_scale) + + # Tested expression + op0 = relay.qnn.op.dequantize(inp, inp_sc, inp_zp) + op1 = relay.op.nn.avg_pool2d(op0, kernel) + expr = relay.qnn.op.quantize(op1, out_sc, out_zp, out_dtype=dtype) + expr = relay.qnn.op.dequantize(expr, out_sc, out_zp) + args = {"input": input_quant} + compare_fq_to_int(hexagon_session, expr, args) + + # Expected graph + op = relay.qnn.op.avg_pool2d( + inp, + input_scale=inp_sc, + input_zero_point=inp_zp, + output_scale=out_sc, + output_zero_point=out_zp, + pool_size=kernel, + strides=stride, + dilation=dilation, + padding=[0, 0, 0, 0], + layout="NCHW", + count_include_pad=False, + ) + ref_expr = relay.qnn.op.dequantize(op, out_sc, out_zp) + compare_graphs(expr, ref_expr) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/contrib/test_hexagon/test_relay_simplify_conv_pat.py b/tests/python/contrib/test_hexagon/test_relay_simplify_conv_pat.py index a85762cc24be..b2c60b083cc1 100644 --- a/tests/python/contrib/test_hexagon/test_relay_simplify_conv_pat.py +++ b/tests/python/contrib/test_hexagon/test_relay_simplify_conv_pat.py @@ -22,12 +22,12 @@ import numpy as np import tvm from tvm.runtime import ndarray as nd -from tvm.relay.backend import Executor from tvm import relay, testing from tvm.contrib.hexagon.transform import simplify_conv_pat from tvm.topi.utils import get_const_tuple from tvm.contrib.hexagon.session import Session from tvm.contrib.hexagon.pytest_plugin import HEXAGON_AOT_LLVM_TARGET +from .infrastructure import build_module, run_module def get_test_module_relay_exprs(isConstScalarMultiplier=True): @@ -123,27 +123,6 @@ def get_expected_output_module( return tvm.IRModule.from_expr(out_func) -def build_module(relay_mod, target): - """builds a relay module for a specified target""" - params = {} - executor = Executor("aot", {"link-params": True}) - lowered = tvm.relay.build( - relay_mod, - tvm.target.Target(target, host=target), - executor=executor, - params=params, - ) - return lowered - - -def run_module(mod, inputs): - """invokes run function of specified module with inputs provided""" - mod.set_input(**inputs) - mod.run() - output = mod.get_output(0).numpy() - return output - - def get_test_modules(): """generates test, expected modules and their inputs""" (