From ccc0b9162f2e983a8810e99c903c7141dbec81b6 Mon Sep 17 00:00:00 2001
From: Perry Gibson <Wheest@users.noreply.github.com>
Date: Tue, 14 Mar 2023 02:36:24 +0000
Subject: [PATCH] [fix][relay][qnn] Bug fix for 8-bit quantized mul (#14286)

* [fix][relay][qnn] Bug fix for 8-bit quantized mul

* Update _annotate.py

Revert the black formatting applied by my text editor, which I had wrongly assumed matched TVM's linter.
---
 python/tvm/relay/quantize/_annotate.py  | 10 ++++++++++
 python/tvm/relay/quantize/_partition.py |  5 +++++
 2 files changed, 15 insertions(+)

diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py
index ff673d23144a..c1a7b50d3f45 100644
--- a/python/tvm/relay/quantize/_annotate.py
+++ b/python/tvm/relay/quantize/_annotate.py
@@ -244,6 +244,16 @@ def multiply_rewrite(ref_call, new_args, ctx):
             rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
         expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
         return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
+    if rhs_kind in [QAnnotateKind.ACTIVATION, QAnnotateKind.INPUT] and lhs_kind is None:
+        # quantize rhs to INPUT field
+        if rhs_kind == QAnnotateKind.ACTIVATION:
+            rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
+        if _analysis.check_constant(lhs_expr):
+            lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.WEIGHT)
+        else:
+            lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
+        expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
+        return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
 
     raise ValueError
 
diff --git a/python/tvm/relay/quantize/_partition.py b/python/tvm/relay/quantize/_partition.py
index 563d28366874..ca980dd9f8fc 100644
--- a/python/tvm/relay/quantize/_partition.py
+++ b/python/tvm/relay/quantize/_partition.py
@@ -133,6 +133,11 @@ def mul_partition_generic(ref_call, new_args, ctx):
         lhs = new_args[0].realize()
         return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
 
+    if rhs_cond:
+        # introduced by efficientnet
+        rhs = new_args[1].realize()
+        return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
+
     if not lhs_cond and not rhs_cond:
         # trivial case
         return None