Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hexagon][UnitTest] Disable flaky quantization test #16337

Merged
merged 2 commits into from
Jan 3, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 69 additions & 46 deletions tests/python/contrib/test_hexagon/test_pass_fq2i_avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,53 +15,24 @@
# specific language governing permissions and limitations
# under the License.

# pylint: disable=redefined-outer-name

""" Tests for avg_pool2d fake quantization to integer """

import numpy as np
import pytest

import tvm
import tvm.testing
import tvm.topi.testing
from tvm import relay
from tvm.contrib.hexagon.session import Session
from tvm.contrib.hexagon.pytest_plugin import HEXAGON_AOT_LLVM_TARGET
from .infrastructure import quantize_np, build_module, run_module


def compare_graphs(expr, ref_expr):
"""Compares the given graph with the expected graph"""
mod = tvm.IRModule.from_expr(expr)
mod = tvm.relay.transform.InferType()(mod)
mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod)
ref_mod = tvm.IRModule.from_expr(ref_expr)
ref_mod = tvm.relay.transform.InferType()(ref_mod)
assert tvm.ir.structural_equal(mod_int["main"], ref_mod["main"], map_free_vars=True)


def compare_fq_to_int(hexagon_session, expr, inputs):
"""Compares the float module output with the integer module output"""
mod = tvm.IRModule.from_expr(expr)
mod = tvm.relay.transform.InferType()(mod)
mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod)
assert not tvm.ir.structural_equal(mod, mod_int)

mod = build_module(
mod, tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET)
)
mod_int = build_module(
mod_int, tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET)
)

hexagon_mod = hexagon_session.get_executor_from_factory(mod)
result = run_module(hexagon_mod, inputs)

hexagon_mod = hexagon_session.get_executor_from_factory(mod_int)
result_int = run_module(hexagon_mod, inputs)

tvm.testing.assert_allclose(result, result_int, rtol=1e-02, atol=1e-02)
from .infrastructure import quantize_np, build_module, run_module


@tvm.testing.requires_hexagon
def test_avgpool_conv2d(hexagon_session: Session):
def _make_avgpool_conv2d():
"""Test case with avg_pool2d followed by a conv2d"""
dtype = "int8"
shape_x = [1, 2, 9, 9]
Expand Down Expand Up @@ -112,8 +83,6 @@ def test_avgpool_conv2d(hexagon_session: Session):
expr = relay.qnn.op.dequantize(expr, out_sc, out_zp)
args = {"input": input_quant, "weight": weight_quant}

compare_fq_to_int(hexagon_session, expr, args)

# Expected graph
op0 = relay.qnn.op.avg_pool2d(
inp,
Expand Down Expand Up @@ -148,11 +117,11 @@ def test_avgpool_conv2d(hexagon_session: Session):
out_dtype="int8",
)
ref_expr = relay.qnn.op.dequantize(op2, out_sc, out_zp)
compare_graphs(expr, ref_expr)

return expr, args, ref_expr

@tvm.testing.requires_hexagon
def test_avgpool_avgpool(hexagon_session: Session):

def _make_avgpool_avgpool():
"""Test case with avg_pool2d followed by an avg_pool2d"""
dtype = "uint8"
shape_x = [1, 2, 9, 9]
Expand Down Expand Up @@ -197,7 +166,6 @@ def test_avgpool_avgpool(hexagon_session: Session):
expr = relay.qnn.op.quantize(op2, out_sc, out_zp, out_dtype=dtype)
expr = relay.qnn.op.dequantize(expr, out_sc, out_zp)
args = {"input": input_quant}
compare_fq_to_int(hexagon_session, expr, args)

# Expected graph
op0 = relay.qnn.op.avg_pool2d(
Expand Down Expand Up @@ -227,12 +195,11 @@ def test_avgpool_avgpool(hexagon_session: Session):
count_include_pad=False,
)
ref_expr = relay.qnn.op.dequantize(op1, out_sc, out_zp)
compare_graphs(expr, ref_expr)

return expr, args, ref_expr

@tvm.testing.requires_hexagon
def test_avgpool(hexagon_session: Session):
"""Test case of a regular avg_pool2d"""

def _make_avgpool():
dtype = "int8"
shape_x = [1, 2, 9, 9]
kernel = [3, 3]
Expand Down Expand Up @@ -266,7 +233,6 @@ def test_avgpool(hexagon_session: Session):
expr = relay.qnn.op.quantize(op1, out_sc, out_zp, out_dtype=dtype)
expr = relay.qnn.op.dequantize(expr, out_sc, out_zp)
args = {"input": input_quant}
compare_fq_to_int(hexagon_session, expr, args)

# Expected graph
op = relay.qnn.op.avg_pool2d(
Expand All @@ -283,6 +249,63 @@ def test_avgpool(hexagon_session: Session):
count_include_pad=False,
)
ref_expr = relay.qnn.op.dequantize(op, out_sc, out_zp)

return expr, args, ref_expr


def compare_graphs(expr, ref_expr):
"""Compares the given graph with the expected graph"""
mod = tvm.IRModule.from_expr(expr)
mod = tvm.relay.transform.InferType()(mod)
mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod)
ref_mod = tvm.IRModule.from_expr(ref_expr)
ref_mod = tvm.relay.transform.InferType()(ref_mod)
tvm.ir.assert_structural_equal(mod_int["main"], ref_mod["main"], map_free_vars=True)


def compare_fq_to_int(hexagon_session, expr, inputs):
"""Compares the float module output with the integer module output"""
mod = tvm.IRModule.from_expr(expr)
mod = tvm.relay.transform.InferType()(mod)
mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod)
assert not tvm.ir.structural_equal(mod, mod_int)

mod = build_module(
mod, tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET)
)
mod_int = build_module(
mod_int, tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET)
)

hexagon_mod = hexagon_session.get_executor_from_factory(mod)
result = run_module(hexagon_mod, inputs)

hexagon_mod = hexagon_session.get_executor_from_factory(mod_int)
result_int = run_module(hexagon_mod, inputs)

tvm.testing.assert_allclose(result, result_int, rtol=1e-02, atol=1e-02)


avgpool_test_case = tvm.testing.parameter(
_make_avgpool,
_make_avgpool_avgpool,
pytest.param(
_make_avgpool_conv2d,
marks=pytest.mark.xfail(
reason="Rounding differences causing mismatch of Constant, difference around 10^-7"
),
),
)


@tvm.testing.requires_hexagon
def test_execution(hexagon_session: Session, avgpool_test_case):
expr, args, _ = avgpool_test_case()
compare_fq_to_int(hexagon_session, expr, args)


def test_quantization(avgpool_test_case):
expr, _, ref_expr = avgpool_test_case()
compare_graphs(expr, ref_expr)


Expand Down
Loading