From ccc0b9162f2e983a8810e99c903c7141dbec81b6 Mon Sep 17 00:00:00 2001 From: Perry Gibson <Wheest@users.noreply.github.com> Date: Tue, 14 Mar 2023 02:36:24 +0000 Subject: [PATCH] [fix][relay][qnn] Bug fix for 8-bit quantized mul (#14286) * [fix][relay][qnn] Bug fix for 8-bit quantized mul * Update _annotate.py Revert black formatting from my text editor, which I had assumed matched TVM's linter --- python/tvm/relay/quantize/_annotate.py | 10 ++++++++++ python/tvm/relay/quantize/_partition.py | 5 +++++ 2 files changed, 15 insertions(+) diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index ff673d23144a..c1a7b50d3f45 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -244,6 +244,16 @@ def multiply_rewrite(ref_call, new_args, ctx): rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) + if rhs_kind in [QAnnotateKind.ACTIVATION, QAnnotateKind.INPUT] and lhs_kind is None: + # quantize rhs to INPUT field + if rhs_kind == QAnnotateKind.ACTIVATION: + rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT) + if _analysis.check_constant(lhs_expr): + lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.WEIGHT) + else: + lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) + expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) + return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) raise ValueError diff --git a/python/tvm/relay/quantize/_partition.py b/python/tvm/relay/quantize/_partition.py index 563d28366874..ca980dd9f8fc 100644 --- a/python/tvm/relay/quantize/_partition.py +++ b/python/tvm/relay/quantize/_partition.py @@ -133,6 +133,11 @@ def mul_partition_generic(ref_call, new_args, ctx): lhs = new_args[0].realize() return QPartitionExpr(_forward_op(ref_call, [lhs, rhs])) + if rhs_cond: + # introduced by efficientnet + rhs = new_args[1].realize() + return QPartitionExpr(_forward_op(ref_call, [lhs, rhs])) + if not lhs_cond and not rhs_cond: # trivial case return None