diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index ff673d23144a..c1a7b50d3f45 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -244,6 +244,16 @@ def multiply_rewrite(ref_call, new_args, ctx): rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) + if rhs_kind in [QAnnotateKind.ACTIVATION, QAnnotateKind.INPUT] and lhs_kind is None: + # quantize rhs to INPUT field + if rhs_kind == QAnnotateKind.ACTIVATION: + rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT) + if _analysis.check_constant(lhs_expr): + lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.WEIGHT) + else: + lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) + expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) + return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) raise ValueError diff --git a/python/tvm/relay/quantize/_partition.py b/python/tvm/relay/quantize/_partition.py index 563d28366874..ca980dd9f8fc 100644 --- a/python/tvm/relay/quantize/_partition.py +++ b/python/tvm/relay/quantize/_partition.py @@ -133,6 +133,11 @@ def mul_partition_generic(ref_call, new_args, ctx): lhs = new_args[0].realize() return QPartitionExpr(_forward_op(ref_call, [lhs, rhs])) + if rhs_cond: + # introduced by efficientnet + rhs = new_args[1].realize() + return QPartitionExpr(_forward_op(ref_call, [lhs, rhs])) + if not lhs_cond and not rhs_cond: # trivial case return None