diff --git a/tests/python/contrib/test_hexagon/test_pass_fq2i_avg_pool2d.py b/tests/python/contrib/test_hexagon/test_pass_fq2i_avg_pool2d.py index 34f356a0158c..e45f56ba171c 100644 --- a/tests/python/contrib/test_hexagon/test_pass_fq2i_avg_pool2d.py +++ b/tests/python/contrib/test_hexagon/test_pass_fq2i_avg_pool2d.py @@ -15,53 +15,24 @@ # specific language governing permissions and limitations # under the License. +# pylint: disable=redefined-outer-name + """ Tests for avg_pool2d fake quantization to integer """ import numpy as np +import pytest + import tvm import tvm.testing import tvm.topi.testing from tvm import relay from tvm.contrib.hexagon.session import Session from tvm.contrib.hexagon.pytest_plugin import HEXAGON_AOT_LLVM_TARGET -from .infrastructure import quantize_np, build_module, run_module - - -def compare_graphs(expr, ref_expr): - """Compares the given graph with the expected graph""" - mod = tvm.IRModule.from_expr(expr) - mod = tvm.relay.transform.InferType()(mod) - mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod) - ref_mod = tvm.IRModule.from_expr(ref_expr) - ref_mod = tvm.relay.transform.InferType()(ref_mod) - assert tvm.ir.structural_equal(mod_int["main"], ref_mod["main"], map_free_vars=True) - - -def compare_fq_to_int(hexagon_session, expr, inputs): - """Compares the float module output with the integer module output""" - mod = tvm.IRModule.from_expr(expr) - mod = tvm.relay.transform.InferType()(mod) - mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod_int) - - mod = build_module( - mod, tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET) - ) - mod_int = build_module( - mod_int, tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET) - ) - - hexagon_mod = hexagon_session.get_executor_from_factory(mod) - result = run_module(hexagon_mod, inputs) - - hexagon_mod = hexagon_session.get_executor_from_factory(mod_int) - result_int = run_module(hexagon_mod, inputs) - tvm.testing.assert_allclose(result, result_int, rtol=1e-02, atol=1e-02) +from .infrastructure import quantize_np, build_module, run_module -@tvm.testing.requires_hexagon -def test_avgpool_conv2d(hexagon_session: Session): +def _make_avgpool_conv2d(): """Test case with avg_pool2d followed by a conv2d""" dtype = "int8" shape_x = [1, 2, 9, 9] @@ -112,8 +83,6 @@ def test_avgpool_conv2d(hexagon_session: Session): expr = relay.qnn.op.dequantize(expr, out_sc, out_zp) args = {"input": input_quant, "weight": weight_quant} - compare_fq_to_int(hexagon_session, expr, args) - # Expected graph op0 = relay.qnn.op.avg_pool2d( inp, @@ -148,11 +117,11 @@ def test_avgpool_conv2d(hexagon_session: Session): out_dtype="int8", ) ref_expr = relay.qnn.op.dequantize(op2, out_sc, out_zp) - compare_graphs(expr, ref_expr) + return expr, args, ref_expr -@tvm.testing.requires_hexagon -def test_avgpool_avgpool(hexagon_session: Session): + +def _make_avgpool_avgpool(): """Test case with avg_pool2d followed by an avg_pool2d""" dtype = "uint8" shape_x = [1, 2, 9, 9] @@ -197,7 +166,6 @@ def test_avgpool_avgpool(hexagon_session: Session): expr = relay.qnn.op.quantize(op2, out_sc, out_zp, out_dtype=dtype) expr = relay.qnn.op.dequantize(expr, out_sc, out_zp) args = {"input": input_quant} - compare_fq_to_int(hexagon_session, expr, args) # Expected graph op0 = relay.qnn.op.avg_pool2d( @@ -227,12 +195,11 @@ def test_avgpool_avgpool(hexagon_session: Session): count_include_pad=False, ) ref_expr = relay.qnn.op.dequantize(op1, out_sc, out_zp) - compare_graphs(expr, ref_expr) + return expr, args, ref_expr -@tvm.testing.requires_hexagon -def test_avgpool(hexagon_session: Session): - """Test case of a regular avg_pool2d""" + +def _make_avgpool(): dtype = "int8" shape_x = [1, 2, 9, 9] kernel = [3, 3] @@ -266,7 +233,6 @@ def test_avgpool(hexagon_session: Session): expr = relay.qnn.op.quantize(op1, out_sc, out_zp, out_dtype=dtype) expr = relay.qnn.op.dequantize(expr, out_sc, out_zp) args = {"input": input_quant} - compare_fq_to_int(hexagon_session, expr, args) # Expected graph op = relay.qnn.op.avg_pool2d( @@ -283,6 +249,63 @@ def test_avgpool(hexagon_session: Session): count_include_pad=False, ) ref_expr = relay.qnn.op.dequantize(op, out_sc, out_zp) + + return expr, args, ref_expr + + +def compare_graphs(expr, ref_expr): + """Compares the given graph with the expected graph""" + mod = tvm.IRModule.from_expr(expr) + mod = tvm.relay.transform.InferType()(mod) + mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod) + ref_mod = tvm.IRModule.from_expr(ref_expr) + ref_mod = tvm.relay.transform.InferType()(ref_mod) + tvm.ir.assert_structural_equal(mod_int["main"], ref_mod["main"], map_free_vars=True) + + +def compare_fq_to_int(hexagon_session, expr, inputs): + """Compares the float module output with the integer module output""" + mod = tvm.IRModule.from_expr(expr) + mod = tvm.relay.transform.InferType()(mod) + mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod) + assert not tvm.ir.structural_equal(mod, mod_int) + + mod = build_module( + mod, tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET) + ) + mod_int = build_module( + mod_int, tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET) + ) + + hexagon_mod = hexagon_session.get_executor_from_factory(mod) + result = run_module(hexagon_mod, inputs) + + hexagon_mod = hexagon_session.get_executor_from_factory(mod_int) + result_int = run_module(hexagon_mod, inputs) + + tvm.testing.assert_allclose(result, result_int, rtol=1e-02, atol=1e-02) + + +avgpool_test_case = tvm.testing.parameter( + _make_avgpool, + _make_avgpool_avgpool, + pytest.param( + _make_avgpool_conv2d, + marks=pytest.mark.xfail( + reason="Rounding differences causing mismatch of Constant, difference around 10^-7" + ), + ), +) + + +@tvm.testing.requires_hexagon +def test_execution(hexagon_session: Session, avgpool_test_case): + expr, args, _ = avgpool_test_case() + compare_fq_to_int(hexagon_session, expr, args) + + +def test_quantization(avgpool_test_case): + expr, _, ref_expr = avgpool_test_case() compare_graphs(expr, ref_expr)