From ee41a27f060f85214ff67dab1c060bc76495c77c Mon Sep 17 00:00:00 2001
From: Alexey-Yazev <113356454+Alexey-Yazev@users.noreply.github.com>
Date: Mon, 19 Dec 2022 16:28:39 +0400
Subject: [PATCH 1/5] [microNPU] Add relu6 and relu_n1_to_1 test cases for Ethos-U

Tests are extended with relu6 and relu_n1_to_1 activation cases.
Min and max operations are no longer fused with requantize when the input and output scales differ, since that is not supported on the NPU.
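For illustration, a minimal sketch (input shapes and quantization ranges assumed) of a graph
that exercises the new minimum_clip_requantize_pattern(): the relu6 on top of the minimum is
expected to be converted into clip followed by qnn.requantize, and when the requantize output
scale differs from the input scale MinParams.is_valid() now rejects the match, so the
requantize is left unfused.

    import tensorflow as tf

    @tf.function
    def min_relu6(lhs, rhs):
        # Expected to legalize to MINIMUM -> clip -> qnn.requantize in the quantized graph.
        return tf.nn.relu6(tf.math.minimum(lhs, rhs))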
---
 python/tvm/relay/op/contrib/ethosu.py         | 62 ++++++++++++----
 .../contrib/test_ethosu/test_codegen.py       | 74 ++++++++++++-------
 .../contrib/test_ethosu/test_legalize.py      | 10 +++
 3 files changed, 103 insertions(+), 43 deletions(-)

diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index c0f8e5e9708e..e1ec3172428b 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -688,15 +688,13 @@ def __init__(self, func_body: Call, operator_type: str, is_quantized_operation:
         clip = None
         requantize = None
 
-        if is_quantized_operation:
-            if str(current_call.op) == "clip":
-                clip = current_call
-                current_call = clip.args[0]
-        else:
-            if str(current_call.op) == "qnn.requantize":
-                requantize = current_call
-                clip = current_call.args[0]
-                current_call = clip.args[0]
+        if str(current_call.op) == "clip":
+            clip = current_call
+            current_call = clip.args[0]
+        elif str(current_call.op) == "qnn.requantize":
+            requantize = current_call
+            clip = current_call.args[0]
+            current_call = clip.args[0]
         binary_op = current_call
 
         layout = "NHWC"
@@ -929,6 +927,9 @@ def is_valid(self):
             [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8]
         ):
             return False
+        # MIN with different scales is not supported on NPU.
+        if self.ifm.q_params.scale_f32 != self.ofm.q_params.scale_f32:
+            return False
         return True
 
 
@@ -938,12 +939,21 @@ def minimum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
     """
     minimum = is_op("minimum")(wildcard(), wildcard())
     optional_min_clip = is_op("clip")(minimum)
-    optional_min_clip = is_op("qnn.requantize")(
-        optional_min_clip, is_constant(), is_constant(), is_constant(), is_constant()
-    )
     return minimum | optional_min_clip
 
 
+def minimum_clip_requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for minimum with fused RELU activation.
+    """
+    pattern = is_op("minimum")(wildcard(), wildcard())
+    pattern = is_op("clip")(pattern)
+    pattern = is_op("qnn.requantize")(
+        pattern, is_constant(), is_constant(), is_constant(), is_constant()
+    )
+    return pattern
+
+
 class MaxParams(BinaryElementwiseParams):
     """
     This class will parse a call to a ethosu.binary_elementwise Max composite function
@@ -967,6 +977,9 @@ def is_valid(self):
             [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8]
         ):
             return False
+        # MAX with different scales is not supported on NPU.
+        if self.ifm.q_params.scale_f32 != self.ofm.q_params.scale_f32:
+            return False
         return True
 
 
@@ -976,12 +989,21 @@ def maximum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
     """
     maximum = is_op("maximum")(wildcard(), wildcard())
     optional_max_clip = is_op("clip")(maximum)
-    optional_max_clip = is_op("qnn.requantize")(
-        optional_max_clip, is_constant(), is_constant(), is_constant(), is_constant()
-    )
     return maximum | optional_max_clip
 
 
+def maximum_clip_requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for maximum with fused RELU activation.
+    """
+    pattern = is_op("maximum")(wildcard(), wildcard())
+    pattern = is_op("clip")(pattern)
+    pattern = is_op("qnn.requantize")(
+        pattern, is_constant(), is_constant(), is_constant(), is_constant()
+    )
+    return pattern
+
+
 class ShlParams(BinaryElementwiseParams):
     """
     This class will parse a call to a ethosu.binary_elementwise Shl composite function
@@ -1820,11 +1842,21 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
             qnn_mul_pattern(),
             lambda pat: MulParams(pat).is_valid(),
         ),
+        (
+            MinParams.composite_name,
+            minimum_clip_requantize_pattern(),
+            lambda pat: MinParams(pat).is_valid(),
+        ),
         (
             MinParams.composite_name,
             minimum_pattern(),
             lambda pat: MinParams(pat).is_valid(),
         ),
+        (
+            MaxParams.composite_name,
+            maximum_clip_requantize_pattern(),
+            lambda pat: MaxParams(pat).is_valid(),
+        ),
         (
             MaxParams.composite_name,
             maximum_pattern(),
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index e06e36638d7f..8d086742562f 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -37,6 +37,16 @@
 ACCEL_TYPES = ["ethos-u55-256", "ethos-u55-128", "ethos-u55-64", "ethos-u55-32", "ethos-u65-256"]
 
 
+def relu_n1_to_1(x):
+    """
+    This pattern will be converted to RELU_N1_TO_1 by the TFLite converter.
+    """
+    return tf.math.maximum(-1.0, tf.math.minimum(x, 1.0))
+
+
+ACTIVATIONS = [None, tf.nn.relu, tf.nn.relu6, relu_n1_to_1]
+
+
 def is_u55_accel_type(accel_type):
     return "u55" in accel_type
 
@@ -46,7 +56,7 @@ def is_u55_accel_type(accel_type):
 @pytest.mark.parametrize("kernel_shape", [(3, 2), (1, 3)])
 @pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))])
 @pytest.mark.parametrize("padding", ["SAME", "VALID"])
-@pytest.mark.parametrize("activation", ["NONE", "RELU"])
+@pytest.mark.parametrize("activation", ACTIVATIONS)
 def test_ethosu_conv2d_single(
     ifm_shape,
     kernel_shape,
@@ -72,8 +82,8 @@ def conv2d(x):
             padding=padding,
             dilations=dilation,
         )
-        if activation == "RELU":
-            op = tf.nn.relu(op)
+        if activation:
+            op = activation(op)
         return op
 
     infra.compare_tvm_with_tflite(conv2d, [ifm_shape], accel_type)
@@ -114,7 +124,7 @@ def conv2d(x):
 @pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))])
 @pytest.mark.parametrize("padding", ["SAME", "VALID"])
 @pytest.mark.parametrize("accel_type", ACCEL_TYPES + ["ethos-u65-512"])
-@pytest.mark.parametrize("activation", ["NONE", "RELU"])
+@pytest.mark.parametrize("activation", ACTIVATIONS)
 def test_ethosu_conv2d_double(
     ifm_shape,
     kernel_shape,
@@ -150,22 +160,28 @@ def conv2d_double(x):
             padding=padding,
             dilations=dilation,
         )
-        if activation == "RELU":
-            op2 = tf.nn.relu(op2)
+        if activation:
+            op2 = activation(op2)
         return op2
 
     infra.compare_tvm_with_tflite(conv2d_double, [ifm_shape], accel_type)
 
 
 @pytest.mark.parametrize("weight_min, weight_max", [(0.0, 1e-11), (-1e10, 1e10)])
-def test_out_of_range_scaling(weight_min, weight_max):
+# relu6 and relu_n1_to_1 are excluded from the activation list since the TFLite results differ:
+# the TFLite model generates a rather large scale, so in some cases the TFLite result is -128 while the Ethos-U result is 127.
+@pytest.mark.parametrize("activation", [None, tf.nn.relu])
+def test_out_of_range_scaling(
+    weight_min,
+    weight_max,
+    activation,
+):
     np.random.seed(0)
     ifm_shape = (1, 6, 6, 2)
     strides = (1, 1)
     kernel_shape = (1, 1)
     dilation = (1, 1)
     padding = "SAME"
-    activation = "RELU"
     accel_type = "ethos-u55-128"
 
     @tf.function
@@ -186,8 +202,8 @@ def conv_invalid_scale(x):
             padding=padding,
             dilations=dilation,
         )
-        if activation == "RELU":
-            op = tf.nn.relu(op)
+        if activation:
+            op = activation(op)
         return op
 
     infra.compare_tvm_with_tflite(conv_invalid_scale, [ifm_shape], accel_type)
@@ -196,11 +212,12 @@ def conv_invalid_scale(x):
 @pytest.mark.parametrize("accel_type", ACCEL_TYPES)
 @pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3), (1, 23, 32, 7)])
 @pytest.mark.parametrize(
-    "kernel_shape, activation_function",
-    [((3, 3), "RELU"), ((1, 2), "NONE")],
+    "kernel_shape",
+    [(3, 3), (1, 2)],
 )
 @pytest.mark.parametrize("padding", ["SAME", "VALID"])
 @pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 2)), ((3, 2), (1, 1))])
+@pytest.mark.parametrize("activation", ACTIVATIONS)
 def test_tflite_depthwise_conv2d(
     accel_type,
     ifm_shape,
@@ -208,7 +225,7 @@ def test_tflite_depthwise_conv2d(
     padding,
     strides,
     dilation,
-    activation_function,
+    activation,
 ):
     np.random.seed(0)
 
@@ -221,8 +238,8 @@ def depthwise_conv2d(x):
         op = tf.nn.depthwise_conv2d(
             x, weight, strides=tf_strides, padding=padding, dilations=dilation
         )
-        if activation_function == "RELU":
-            op = tf.nn.relu(op)
+        if activation:
+            op = activation(op)
         return op
 
     infra.compare_tvm_with_tflite(depthwise_conv2d, [ifm_shape], accel_type)
@@ -265,17 +282,18 @@ def depthwise_conv2d(x):
 @pytest.mark.parametrize("pooling_type", ["MAX", "AVG"])
 @pytest.mark.parametrize("ifm_shape", [[1, 3, 4, 3], [1, 4, 5, 2]])
 @pytest.mark.parametrize(
-    "pool_shape, strides, activation_function, padding",
-    [([1, 2], [1, 2], "NONE", "SAME"), ([2, 3], [2, 3], "RELU", "VALID")],
+    "pool_shape, strides, padding",
+    [([1, 2], [1, 2], "SAME"), ([2, 3], [2, 3], "VALID")],
 )
+@pytest.mark.parametrize("activation", ACTIVATIONS)
 def test_ethosu_pooling(
     accel_type,
     ifm_shape,
     pooling_type,
     strides,
     pool_shape,
-    activation_function,
     padding,
+    activation,
 ):
     np.random.seed(0)
 
@@ -285,8 +303,8 @@ def pooling(x):
             op = tf.nn.max_pool(x, pool_shape, strides, padding)
         elif pooling_type == "AVG":
             op = tf.nn.avg_pool(x, pool_shape, strides, padding)
-        if activation_function == "RELU":
-            op = tf.nn.relu(op)
+        if activation:
+            op = activation(op)
         return op
 
     infra.compare_tvm_with_tflite(pooling, [ifm_shape], accel_type)
@@ -303,13 +321,13 @@ def pooling(x):
         ([1, 4, 4], [4, 1]),
     ],
 )
-@pytest.mark.parametrize("activation_function", ["NONE", "RELU"])
+@pytest.mark.parametrize("activation", ACTIVATIONS)
 def test_ethosu_binary_elementwise(
     accel_type,
     operator_type,
     ifm_shape,
     ifm2_shape,
-    activation_function,
+    activation,
 ):
     np.random.seed(0)
 
@@ -325,8 +343,8 @@ def binary_elementwise(lhs, rhs):
             op = tf.math.minimum(lhs, rhs)
         elif operator_type == "MAX":
             op = tf.math.maximum(lhs, rhs)
-        if activation_function == "RELU":
-            op = tf.nn.relu(op)
+        if activation:
+            op = activation(op)
         return op
 
     infra.compare_tvm_with_tflite(
@@ -1113,13 +1131,13 @@ def leaky_relu_func(x):
 @pytest.mark.parametrize("ifm_shape", [(1, 14), (1, 151)])
 @pytest.mark.parametrize("ofm_channels", [32, 64])
 @pytest.mark.parametrize("use_bias", [True, False])
-@pytest.mark.parametrize("activation_function", ["RELU", "NONE"])
+@pytest.mark.parametrize("activation", ACTIVATIONS)
 def test_tflite_fully_connected(
     accel_type,
     ifm_shape,
     ofm_channels,
     use_bias,
-    activation_function,
+    activation,
 ):
     np.random.seed(0)
 
@@ -1134,8 +1152,8 @@ def fully_connected(x):
         x = tf.matmul(x, w)
         if use_bias:
             x = tf.nn.bias_add(x, bias)
-        if activation_function:
-            x = tf.nn.relu(x)
+        if activation:
+            x = activation(x)
         return x
 
     infra.compare_tvm_with_tflite(
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py
index 9b4dd467ff9f..7641d9c4e6f0 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -899,6 +899,11 @@ def verify(ext_func):
     elif operator_type == "MIN":
         rewriter = legalize.MinRewriter()
         pattern_table = [
+            (
+                ethosu.MinParams.composite_name,
+                ethosu.minimum_clip_requantize_pattern(),
+                lambda pat: ethosu.MinParams(pat).is_valid(),
+            ),
             (
                 ethosu.MinParams.composite_name,
                 ethosu.minimum_pattern(),
@@ -908,6 +913,11 @@ def verify(ext_func):
     elif operator_type == "MAX":
         rewriter = legalize.MaxRewriter()
         pattern_table = [
+            (
+                ethosu.MaxParams.composite_name,
+                ethosu.maximum_clip_requantize_pattern(),
+                lambda pat: ethosu.MaxParams(pat).is_valid(),
+            ),
             (
                 ethosu.MaxParams.composite_name,
                 ethosu.maximum_pattern(),

From 2da2b0e4663decef6461bf1bda963cf8ab96acdd Mon Sep 17 00:00:00 2001
From: Alexey-Yazev <113356454+Alexey-Yazev@users.noreply.github.com>
Date: Wed, 28 Dec 2022 18:28:49 +0400
Subject: [PATCH 2/5] add separate tests for relu6, relu_n1_to_1

---
 .../contrib/test_ethosu/test_codegen.py       | 113 +++++++++++-------
 1 file changed, 67 insertions(+), 46 deletions(-)

diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index 8d086742562f..b318426c08c8 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -37,16 +37,6 @@
 ACCEL_TYPES = ["ethos-u55-256", "ethos-u55-128", "ethos-u55-64", "ethos-u55-32", "ethos-u65-256"]
 
 
-def relu_n1_to_1(x):
-    """
-    This pattern will be converted to RELU_N1_TO_1 by the TFLite converter.
-    """
-    return tf.math.maximum(-1.0, tf.math.minimum(x, 1.0))
-
-
-ACTIVATIONS = [None, tf.nn.relu, tf.nn.relu6, relu_n1_to_1]
-
-
 def is_u55_accel_type(accel_type):
     return "u55" in accel_type
 
@@ -56,7 +46,7 @@ def is_u55_accel_type(accel_type):
 @pytest.mark.parametrize("kernel_shape", [(3, 2), (1, 3)])
 @pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))])
 @pytest.mark.parametrize("padding", ["SAME", "VALID"])
-@pytest.mark.parametrize("activation", ACTIVATIONS)
+@pytest.mark.parametrize("activation", ["NONE", "RELU"])
 def test_ethosu_conv2d_single(
     ifm_shape,
     kernel_shape,
@@ -82,8 +72,8 @@ def conv2d(x):
             padding=padding,
             dilations=dilation,
         )
-        if activation:
-            op = activation(op)
+        if activation == "RELU":
+            op = tf.nn.relu(op)
         return op
 
     infra.compare_tvm_with_tflite(conv2d, [ifm_shape], accel_type)
@@ -124,7 +114,7 @@ def conv2d(x):
 @pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))])
 @pytest.mark.parametrize("padding", ["SAME", "VALID"])
 @pytest.mark.parametrize("accel_type", ACCEL_TYPES + ["ethos-u65-512"])
-@pytest.mark.parametrize("activation", ACTIVATIONS)
+@pytest.mark.parametrize("activation", ["NONE", "RELU"])
 def test_ethosu_conv2d_double(
     ifm_shape,
     kernel_shape,
@@ -160,28 +150,22 @@ def conv2d_double(x):
             padding=padding,
             dilations=dilation,
         )
-        if activation:
-            op2 = activation(op2)
+        if activation == "RELU":
+            op2 = tf.nn.relu(op2)
         return op2
 
     infra.compare_tvm_with_tflite(conv2d_double, [ifm_shape], accel_type)
 
 
 @pytest.mark.parametrize("weight_min, weight_max", [(0.0, 1e-11), (-1e10, 1e10)])
-# relu6 and relu_n1_to_1 are excluded from the activation list since the TFLite results differ:
-# the TFLite model generates a rather large scale, so in some cases the TFLite result is -128 while the Ethos-U result is 127.
-@pytest.mark.parametrize("activation", [None, tf.nn.relu])
-def test_out_of_range_scaling(
-    weight_min,
-    weight_max,
-    activation,
-):
+def test_out_of_range_scaling(weight_min, weight_max):
     np.random.seed(0)
     ifm_shape = (1, 6, 6, 2)
     strides = (1, 1)
     kernel_shape = (1, 1)
     dilation = (1, 1)
     padding = "SAME"
+    activation = "RELU"
     accel_type = "ethos-u55-128"
 
     @tf.function
@@ -202,8 +186,8 @@ def conv_invalid_scale(x):
             padding=padding,
             dilations=dilation,
         )
-        if activation:
-            op = activation(op)
+        if activation == "RELU":
+            op = tf.nn.relu(op)
         return op
 
     infra.compare_tvm_with_tflite(conv_invalid_scale, [ifm_shape], accel_type)
@@ -212,12 +196,11 @@ def conv_invalid_scale(x):
 @pytest.mark.parametrize("accel_type", ACCEL_TYPES)
 @pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3), (1, 23, 32, 7)])
 @pytest.mark.parametrize(
-    "kernel_shape",
-    [(3, 3), (1, 2)],
+    "kernel_shape, activation_function",
+    [((3, 3), "RELU"), ((1, 2), "NONE")],
 )
 @pytest.mark.parametrize("padding", ["SAME", "VALID"])
 @pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 2)), ((3, 2), (1, 1))])
-@pytest.mark.parametrize("activation", ACTIVATIONS)
 def test_tflite_depthwise_conv2d(
     accel_type,
     ifm_shape,
@@ -225,7 +208,7 @@ def test_tflite_depthwise_conv2d(
     padding,
     strides,
     dilation,
-    activation,
+    activation_function,
 ):
     np.random.seed(0)
 
@@ -238,8 +221,8 @@ def depthwise_conv2d(x):
         op = tf.nn.depthwise_conv2d(
             x, weight, strides=tf_strides, padding=padding, dilations=dilation
         )
-        if activation:
-            op = activation(op)
+        if activation_function == "RELU":
+            op = tf.nn.relu(op)
         return op
 
     infra.compare_tvm_with_tflite(depthwise_conv2d, [ifm_shape], accel_type)
@@ -282,18 +265,17 @@ def depthwise_conv2d(x):
 @pytest.mark.parametrize("pooling_type", ["MAX", "AVG"])
 @pytest.mark.parametrize("ifm_shape", [[1, 3, 4, 3], [1, 4, 5, 2]])
 @pytest.mark.parametrize(
-    "pool_shape, strides, padding",
-    [([1, 2], [1, 2], "SAME"), ([2, 3], [2, 3], "VALID")],
+    "pool_shape, strides, activation_function, padding",
+    [([1, 2], [1, 2], "NONE", "SAME"), ([2, 3], [2, 3], "RELU", "VALID")],
 )
-@pytest.mark.parametrize("activation", ACTIVATIONS)
 def test_ethosu_pooling(
     accel_type,
     ifm_shape,
     pooling_type,
     strides,
     pool_shape,
+    activation_function,
     padding,
-    activation,
 ):
     np.random.seed(0)
 
@@ -303,8 +285,8 @@ def pooling(x):
             op = tf.nn.max_pool(x, pool_shape, strides, padding)
         elif pooling_type == "AVG":
             op = tf.nn.avg_pool(x, pool_shape, strides, padding)
-        if activation:
-            op = activation(op)
+        if activation_function == "RELU":
+            op = tf.nn.relu(op)
         return op
 
     infra.compare_tvm_with_tflite(pooling, [ifm_shape], accel_type)
@@ -321,13 +303,13 @@ def pooling(x):
         ([1, 4, 4], [4, 1]),
     ],
 )
-@pytest.mark.parametrize("activation", ACTIVATIONS)
+@pytest.mark.parametrize("activation_function", ["NONE", "RELU"])
 def test_ethosu_binary_elementwise(
     accel_type,
     operator_type,
     ifm_shape,
     ifm2_shape,
-    activation,
+    activation_function,
 ):
     np.random.seed(0)
 
@@ -343,8 +325,8 @@ def binary_elementwise(lhs, rhs):
             op = tf.math.minimum(lhs, rhs)
         elif operator_type == "MAX":
             op = tf.math.maximum(lhs, rhs)
-        if activation:
-            op = activation(op)
+        if activation_function == "RELU":
+            op = tf.nn.relu(op)
         return op
 
     infra.compare_tvm_with_tflite(
@@ -1127,17 +1109,56 @@ def leaky_relu_func(x):
     )
 
 
+def test_tflite_relu6():
+    np.random.seed(0)
+    accel_type = "ethos-u55-128"
+    ifm_shape = (1, 12, 16, 8)
+
+    @tf.function
+    def relu6(x):
+        return tf.nn.relu6(x)
+
+    infra.compare_tvm_with_tflite(
+        relu6,
+        [ifm_shape],
+        accel_type,
+        enable_cascader=is_u55_accel_type(accel_type),
+        ranges=[(-1, 1)],
+    )
+
+
+def test_tflite_relu_n1_to_1():
+    np.random.seed(0)
+    accel_type = "ethos-u55-128"
+    ifm_shape = (1, 12, 16, 8)
+
+    @tf.function
+    def relu_n1_to_1(x):
+        """
+        This pattern will be converted to RELU_N1_TO_1 by the TFLite converter.
+        """
+        return tf.math.maximum(-1.0, tf.math.minimum(x, 1.0))
+
+    infra.compare_tvm_with_tflite(
+        relu_n1_to_1,
+        [ifm_shape],
+        accel_type,
+        enable_cascader=is_u55_accel_type(accel_type),
+        ranges=[(-1, 1)],
+    )
+
+
 @pytest.mark.parametrize("accel_type", ACCEL_TYPES)
 @pytest.mark.parametrize("ifm_shape", [(1, 14), (1, 151)])
 @pytest.mark.parametrize("ofm_channels", [32, 64])
 @pytest.mark.parametrize("use_bias", [True, False])
-@pytest.mark.parametrize("activation", ACTIVATIONS)
+@pytest.mark.parametrize("activation_function", ["RELU", "NONE"])
 def test_tflite_fully_connected(
     accel_type,
     ifm_shape,
     ofm_channels,
     use_bias,
-    activation,
+    activation_function,
 ):
     np.random.seed(0)
 
@@ -1152,8 +1173,8 @@ def fully_connected(x):
         x = tf.matmul(x, w)
         if use_bias:
             x = tf.nn.bias_add(x, bias)
-        if activation:
-            x = activation(x)
+        if activation_function:
+            x = tf.nn.relu(x)
         return x
 
     infra.compare_tvm_with_tflite(

From f692260b2b8bdbb904ec22f10c212edb33957bdf Mon Sep 17 00:00:00 2001
From: Alexey-Yazev <113356454+Alexey-Yazev@users.noreply.github.com>
Date: Fri, 30 Dec 2022 17:23:46 +0400
Subject: [PATCH 3/5] keep only the max_relu_n1_to_1 test

Keep relu_n1_to_1 to cover the case where the activation cannot be fused with the previous operation.
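For reference, a sketch of the kept case (input shapes and ranges assumed as in the test):
the clamp to [-1, 1] after the maximum is the pattern the TFLite converter turns into
RELU_N1_TO_1.

    import tensorflow as tf

    @tf.function
    def max_relu_n1_to_1(lhs, rhs):
        op = tf.math.maximum(lhs, rhs)
        # Converted to RELU_N1_TO_1 by the TFLite converter.
        return tf.math.maximum(-1.0, tf.math.minimum(op, 1.0))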
---
 .../contrib/test_ethosu/test_codegen.py       | 33 ++++---------------
 1 file changed, 7 insertions(+), 26 deletions(-)

diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index b318426c08c8..62372a05d39f 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -1109,42 +1109,23 @@ def leaky_relu_func(x):
     )
 
 
-def test_tflite_relu6():
-    np.random.seed(0)
-    accel_type = "ethos-u55-128"
-    ifm_shape = (1, 12, 16, 8)
-
-    @tf.function
-    def relu6(x):
-        return tf.nn.relu6(x)
-
-    infra.compare_tvm_with_tflite(
-        relu6,
-        [ifm_shape],
-        accel_type,
-        enable_cascader=is_u55_accel_type(accel_type),
-        ranges=[(-1, 1)],
-    )
-
-
 def test_tflite_relu_n1_to_1():
     np.random.seed(0)
     accel_type = "ethos-u55-128"
     ifm_shape = (1, 12, 16, 8)
 
     @tf.function
-    def relu_n1_to_1(x):
-        """
-        This pattern will be converted to RELU_N1_TO_1 by the TFLite converter.
-        """
-        return tf.math.maximum(-1.0, tf.math.minimum(x, 1.0))
+    def max_relu_n1_to_1(lhs, rhs):
+        op = tf.math.maximum(lhs, rhs)
+        # This pattern will be converted to RELU_N1_TO_1 by the TFLite converter.
+        return tf.math.maximum(-1.0, tf.math.minimum(op, 1.0))
 
     infra.compare_tvm_with_tflite(
-        relu_n1_to_1,
-        [ifm_shape],
+        max_relu_n1_to_1,
+        [ifm_shape, ifm_shape],
         accel_type,
         enable_cascader=is_u55_accel_type(accel_type),
-        ranges=[(-1, 1)],
+        ranges=[(-1, 1), (0, 2)],
     )
 
 

From 87b449d30342b932c4223370c061910823e748ca Mon Sep 17 00:00:00 2001
From: Alexey-Yazev <113356454+Alexey-Yazev@users.noreply.github.com>
Date: Wed, 11 Jan 2023 18:13:37 +0400
Subject: [PATCH 4/5] add relu6, relu_n1_to_1 test cases for Ethos-U, remove
 changes to the BinaryElementwise operation
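
Since a standalone activation is not offloaded to the NPU, the new tests attach the
activation to a conv2d. A minimal sketch (kernel and input shapes assumed, mirroring the
test added below):

    import numpy as np
    import tensorflow as tf

    ifm_shape = (1, 55, 34, 3)

    @tf.function
    def conv2d_relu6(x):
        weight = tf.constant(np.random.uniform(size=(3, 2, ifm_shape[3], 3)), dtype=tf.float32)
        op = tf.nn.conv2d(x, weight, strides=[1, 1, 1, 1], padding="VALID")
        return tf.nn.relu6(op)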

---
 python/tvm/relay/op/contrib/ethosu.py         | 62 +++++------------
 .../contrib/test_ethosu/test_codegen.py       | 67 ++++++++++++++++---
 .../contrib/test_ethosu/test_legalize.py      | 10 ---
 3 files changed, 74 insertions(+), 65 deletions(-)

diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index e1ec3172428b..c0f8e5e9708e 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -688,13 +688,15 @@ def __init__(self, func_body: Call, operator_type: str, is_quantized_operation:
         clip = None
         requantize = None
 
-        if str(current_call.op) == "clip":
-            clip = current_call
-            current_call = clip.args[0]
-        elif str(current_call.op) == "qnn.requantize":
-            requantize = current_call
-            clip = current_call.args[0]
-            current_call = clip.args[0]
+        if is_quantized_operation:
+            if str(current_call.op) == "clip":
+                clip = current_call
+                current_call = clip.args[0]
+        else:
+            if str(current_call.op) == "qnn.requantize":
+                requantize = current_call
+                clip = current_call.args[0]
+                current_call = clip.args[0]
         binary_op = current_call
 
         layout = "NHWC"
@@ -927,9 +929,6 @@ def is_valid(self):
             [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8]
         ):
             return False
-        # MIN with different scales is not supported on NPU.
-        if self.ifm.q_params.scale_f32 != self.ofm.q_params.scale_f32:
-            return False
         return True
 
 
@@ -939,19 +938,10 @@ def minimum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
     """
     minimum = is_op("minimum")(wildcard(), wildcard())
     optional_min_clip = is_op("clip")(minimum)
-    return minimum | optional_min_clip
-
-
-def minimum_clip_requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
-    """
-    This function creates the pattern for minimum with fused RELU activation.
-    """
-    pattern = is_op("minimum")(wildcard(), wildcard())
-    pattern = is_op("clip")(pattern)
-    pattern = is_op("qnn.requantize")(
-        pattern, is_constant(), is_constant(), is_constant(), is_constant()
+    optional_min_clip = is_op("qnn.requantize")(
+        optional_min_clip, is_constant(), is_constant(), is_constant(), is_constant()
     )
-    return pattern
+    return minimum | optional_min_clip
 
 
 class MaxParams(BinaryElementwiseParams):
@@ -977,9 +967,6 @@ def is_valid(self):
             [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8]
         ):
             return False
-        # MAX with different scales is not supported on NPU.
-        if self.ifm.q_params.scale_f32 != self.ofm.q_params.scale_f32:
-            return False
         return True
 
 
@@ -989,19 +976,10 @@ def maximum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
     """
     maximum = is_op("maximum")(wildcard(), wildcard())
     optional_max_clip = is_op("clip")(maximum)
-    return maximum | optional_max_clip
-
-
-def maximum_clip_requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
-    """
-    This function creates the pattern for maximum with fused RELU activation.
-    """
-    pattern = is_op("maximum")(wildcard(), wildcard())
-    pattern = is_op("clip")(pattern)
-    pattern = is_op("qnn.requantize")(
-        pattern, is_constant(), is_constant(), is_constant(), is_constant()
+    optional_max_clip = is_op("qnn.requantize")(
+        optional_max_clip, is_constant(), is_constant(), is_constant(), is_constant()
     )
-    return pattern
+    return maximum | optional_max_clip
 
 
 class ShlParams(BinaryElementwiseParams):
@@ -1842,21 +1820,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
             qnn_mul_pattern(),
             lambda pat: MulParams(pat).is_valid(),
         ),
-        (
-            MinParams.composite_name,
-            minimum_clip_requantize_pattern(),
-            lambda pat: MinParams(pat).is_valid(),
-        ),
         (
             MinParams.composite_name,
             minimum_pattern(),
             lambda pat: MinParams(pat).is_valid(),
         ),
-        (
-            MaxParams.composite_name,
-            maximum_clip_requantize_pattern(),
-            lambda pat: MaxParams(pat).is_valid(),
-        ),
         (
             MaxParams.composite_name,
             maximum_pattern(),
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index 62372a05d39f..afd5486f360d 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -1109,23 +1109,74 @@ def leaky_relu_func(x):
     )
 
 
+# conv2d + relu_n1_to_1 is used because a standalone activation is not offloaded to the NPU.
 def test_tflite_relu_n1_to_1():
     np.random.seed(0)
-    accel_type = "ethos-u55-128"
-    ifm_shape = (1, 12, 16, 8)
+    accel_type = "ethos-u55-256"
+    ifm_shape = (1, 55, 34, 3)
+    kernel_shape = (3, 2)
+    strides = (1, 1)
+    padding = (1, 0, 1, 1)
 
     @tf.function
-    def max_relu_n1_to_1(lhs, rhs):
-        op = tf.math.maximum(lhs, rhs)
+    def conv2d_relu_n1_to_1(x):
+        tf_strides = [1, strides[0], strides[1], 1]
+        op = tf.pad(
+            x,
+            [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]],
+            "CONSTANT",
+        )
+        weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3]
+        weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+        op = tf.nn.conv2d(
+            op,
+            weight,
+            strides=tf_strides,
+            padding="VALID",
+        )
         # This pattern will be converted to RELU_N1_TO_1 by the TFLite converter.
         return tf.math.maximum(-1.0, tf.math.minimum(op, 1.0))
 
     infra.compare_tvm_with_tflite(
-        max_relu_n1_to_1,
-        [ifm_shape, ifm_shape],
+        conv2d_relu_n1_to_1,
+        [ifm_shape],
         accel_type,
-        enable_cascader=is_u55_accel_type(accel_type),
-        ranges=[(-1, 1), (0, 2)],
+        enable_cascader=True,
+    )
+
+
+# conv2d + relu6 is used because a standalone activation is not offloaded to the NPU.
+def test_tflite_relu6():
+    np.random.seed(0)
+    accel_type = "ethos-u55-256"
+    ifm_shape = (1, 55, 34, 3)
+    kernel_shape = (3, 2)
+    strides = (1, 1)
+    padding = (0, 0, 1, 1)
+
+    @tf.function
+    def conv2d_relu6(x):
+        tf_strides = [1, strides[0], strides[1], 1]
+        op = tf.pad(
+            x,
+            [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]],
+            "CONSTANT",
+        )
+        weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3]
+        weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+        op = tf.nn.conv2d(
+            op,
+            weight,
+            strides=tf_strides,
+            padding="VALID",
+        )
+        return tf.nn.relu6(op)
+
+    infra.compare_tvm_with_tflite(
+        conv2d_relu6,
+        [ifm_shape],
+        accel_type,
+        enable_cascader=True,
     )
 
 
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py
index 7641d9c4e6f0..9b4dd467ff9f 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -899,11 +899,6 @@ def verify(ext_func):
     elif operator_type == "MIN":
         rewriter = legalize.MinRewriter()
         pattern_table = [
-            (
-                ethosu.MinParams.composite_name,
-                ethosu.minimum_clip_requantize_pattern(),
-                lambda pat: ethosu.MinParams(pat).is_valid(),
-            ),
             (
                 ethosu.MinParams.composite_name,
                 ethosu.minimum_pattern(),
@@ -913,11 +908,6 @@ def verify(ext_func):
     elif operator_type == "MAX":
         rewriter = legalize.MaxRewriter()
         pattern_table = [
-            (
-                ethosu.MaxParams.composite_name,
-                ethosu.maximum_clip_requantize_pattern(),
-                lambda pat: ethosu.MaxParams(pat).is_valid(),
-            ),
             (
                 ethosu.MaxParams.composite_name,
                 ethosu.maximum_pattern(),

From 1332b95f42647169f83cbad6ba8a2bba47a253ef Mon Sep 17 00:00:00 2001
From: Alexey-Yazev <113356454+Alexey-Yazev@users.noreply.github.com>
Date: Thu, 12 Jan 2023 11:56:15 +0400
Subject: [PATCH 5/5] remove pad operation from relu6, relu_n1_to_1 test cases

---
 tests/python/contrib/test_ethosu/test_codegen.py | 16 ++--------------
 1 file changed, 2 insertions(+), 14 deletions(-)

diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index afd5486f360d..c164b177c21c 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -1116,20 +1116,14 @@ def test_tflite_relu_n1_to_1():
     ifm_shape = (1, 55, 34, 3)
     kernel_shape = (3, 2)
     strides = (1, 1)
-    padding = (1, 0, 1, 1)
 
     @tf.function
     def conv2d_relu_n1_to_1(x):
         tf_strides = [1, strides[0], strides[1], 1]
-        op = tf.pad(
-            x,
-            [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]],
-            "CONSTANT",
-        )
         weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3]
         weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
         op = tf.nn.conv2d(
-            op,
+            x,
             weight,
             strides=tf_strides,
             padding="VALID",
@@ -1152,20 +1146,14 @@ def test_tflite_relu6():
     ifm_shape = (1, 55, 34, 3)
     kernel_shape = (3, 2)
     strides = (1, 1)
-    padding = (0, 0, 1, 1)
 
     @tf.function
     def conv2d_relu6(x):
         tf_strides = [1, strides[0], strides[1], 1]
-        op = tf.pad(
-            x,
-            [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]],
-            "CONSTANT",
-        )
         weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3]
         weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
         op = tf.nn.conv2d(
-            op,
+            x,
             weight,
             strides=tf_strides,
             padding="VALID",