From ee41a27f060f85214ff67dab1c060bc76495c77c Mon Sep 17 00:00:00 2001 From: Alexey-Yazev <113356454+Alexey-Yazev@users.noreply.github.com> Date: Mon, 19 Dec 2022 16:28:39 +0400 Subject: [PATCH 1/5] [microNPU] Add relu6 relu_n1_to_1 test cases for Ethos-U Tests are extended with cases with activations relu6 and relu_n1_to_1. Does not fuse min and max operations with requantize if there are different scales as it is not supported on NPU. --- python/tvm/relay/op/contrib/ethosu.py | 62 ++++++++++++---- .../contrib/test_ethosu/test_codegen.py | 74 ++++++++++++------- .../contrib/test_ethosu/test_legalize.py | 10 +++ 3 files changed, 103 insertions(+), 43 deletions(-) diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index c0f8e5e9708e..e1ec3172428b 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -688,15 +688,13 @@ def __init__(self, func_body: Call, operator_type: str, is_quantized_operation: clip = None requantize = None - if is_quantized_operation: - if str(current_call.op) == "clip": - clip = current_call - current_call = clip.args[0] - else: - if str(current_call.op) == "qnn.requantize": - requantize = current_call - clip = current_call.args[0] - current_call = clip.args[0] + if str(current_call.op) == "clip": + clip = current_call + current_call = clip.args[0] + elif str(current_call.op) == "qnn.requantize": + requantize = current_call + clip = current_call.args[0] + current_call = clip.args[0] binary_op = current_call layout = "NHWC" @@ -929,6 +927,9 @@ def is_valid(self): [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8] ): return False + # MIN with different scales is not supported on NPU. 
+ if self.ifm.q_params.scale_f32 != self.ofm.q_params.scale_f32: + return False return True @@ -938,12 +939,21 @@ def minimum_pattern() -> tvm.relay.dataflow_pattern.DFPattern: """ minimum = is_op("minimum")(wildcard(), wildcard()) optional_min_clip = is_op("clip")(minimum) - optional_min_clip = is_op("qnn.requantize")( - optional_min_clip, is_constant(), is_constant(), is_constant(), is_constant() - ) return minimum | optional_min_clip +def minimum_clip_requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern: + """ + This function creates the pattern for minimum with fused RELU activation. + """ + pattern = is_op("minimum")(wildcard(), wildcard()) + pattern = is_op("clip")(pattern) + pattern = is_op("qnn.requantize")( + pattern, is_constant(), is_constant(), is_constant(), is_constant() + ) + return pattern + + class MaxParams(BinaryElementwiseParams): """ This class will parse a call to a ethosu.binary_elementwise Max composite function @@ -967,6 +977,9 @@ def is_valid(self): [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8] ): return False + # MAX with different scales is not supported on NPU. + if self.ifm.q_params.scale_f32 != self.ofm.q_params.scale_f32: + return False return True @@ -976,12 +989,21 @@ def maximum_pattern() -> tvm.relay.dataflow_pattern.DFPattern: """ maximum = is_op("maximum")(wildcard(), wildcard()) optional_max_clip = is_op("clip")(maximum) - optional_max_clip = is_op("qnn.requantize")( - optional_max_clip, is_constant(), is_constant(), is_constant(), is_constant() - ) return maximum | optional_max_clip +def maximum_clip_requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern: + """ + This function creates the pattern for maximum with fused RELU activation. 
+ """ + pattern = is_op("maximum")(wildcard(), wildcard()) + pattern = is_op("clip")(pattern) + pattern = is_op("qnn.requantize")( + pattern, is_constant(), is_constant(), is_constant(), is_constant() + ) + return pattern + + class ShlParams(BinaryElementwiseParams): """ This class will parse a call to a ethosu.binary_elementwise Shl composite function @@ -1820,11 +1842,21 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal qnn_mul_pattern(), lambda pat: MulParams(pat).is_valid(), ), + ( + MinParams.composite_name, + minimum_clip_requantize_pattern(), + lambda pat: MinParams(pat).is_valid(), + ), ( MinParams.composite_name, minimum_pattern(), lambda pat: MinParams(pat).is_valid(), ), + ( + MaxParams.composite_name, + maximum_clip_requantize_pattern(), + lambda pat: MaxParams(pat).is_valid(), + ), ( MaxParams.composite_name, maximum_pattern(), diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index e06e36638d7f..8d086742562f 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -37,6 +37,16 @@ ACCEL_TYPES = ["ethos-u55-256", "ethos-u55-128", "ethos-u55-64", "ethos-u55-32", "ethos-u65-256"] +def relu_n1_to_1(x): + """ + The specific pattern will be replaced into RELU_N1_TO_1 by tflite. 
+ """ + return tf.math.maximum(-1.0, tf.math.minimum(x, 1.0)) + + +ACTIVATIONS = [None, tf.nn.relu, tf.nn.relu6, relu_n1_to_1] + + def is_u55_accel_type(accel_type): return "u55" in accel_type @@ -46,7 +56,7 @@ def is_u55_accel_type(accel_type): @pytest.mark.parametrize("kernel_shape", [(3, 2), (1, 3)]) @pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))]) @pytest.mark.parametrize("padding", ["SAME", "VALID"]) -@pytest.mark.parametrize("activation", ["NONE", "RELU"]) +@pytest.mark.parametrize("activation", ACTIVATIONS) def test_ethosu_conv2d_single( ifm_shape, kernel_shape, @@ -72,8 +82,8 @@ def conv2d(x): padding=padding, dilations=dilation, ) - if activation == "RELU": - op = tf.nn.relu(op) + if activation: + op = activation(op) return op infra.compare_tvm_with_tflite(conv2d, [ifm_shape], accel_type) @@ -114,7 +124,7 @@ def conv2d(x): @pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))]) @pytest.mark.parametrize("padding", ["SAME", "VALID"]) @pytest.mark.parametrize("accel_type", ACCEL_TYPES + ["ethos-u65-512"]) -@pytest.mark.parametrize("activation", ["NONE", "RELU"]) +@pytest.mark.parametrize("activation", ACTIVATIONS) def test_ethosu_conv2d_double( ifm_shape, kernel_shape, @@ -150,22 +160,28 @@ def conv2d_double(x): padding=padding, dilations=dilation, ) - if activation == "RELU": - op2 = tf.nn.relu(op2) + if activation: + op2 = activation(op2) return op2 infra.compare_tvm_with_tflite(conv2d_double, [ifm_shape], accel_type) @pytest.mark.parametrize("weight_min, weight_max", [(0.0, 1e-11), (-1e10, 1e10)]) -def test_out_of_range_scaling(weight_min, weight_max): +# relu6 and relu_n1_to_1 operations are excluded from activations since tflite results are different. +# In the tflite model, a rather large scale is generated, so in some cases in tflite result is -128 in ethosu 127. 
+@pytest.mark.parametrize("activation", [None, tf.nn.relu]) +def test_out_of_range_scaling( + weight_min, + weight_max, + activation, +): np.random.seed(0) ifm_shape = (1, 6, 6, 2) strides = (1, 1) kernel_shape = (1, 1) dilation = (1, 1) padding = "SAME" - activation = "RELU" accel_type = "ethos-u55-128" @tf.function @@ -186,8 +202,8 @@ def conv_invalid_scale(x): padding=padding, dilations=dilation, ) - if activation == "RELU": - op = tf.nn.relu(op) + if activation: + op = activation(op) return op infra.compare_tvm_with_tflite(conv_invalid_scale, [ifm_shape], accel_type) @@ -196,11 +212,12 @@ def conv_invalid_scale(x): @pytest.mark.parametrize("accel_type", ACCEL_TYPES) @pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3), (1, 23, 32, 7)]) @pytest.mark.parametrize( - "kernel_shape, activation_function", - [((3, 3), "RELU"), ((1, 2), "NONE")], + "kernel_shape", + [(3, 3), (1, 2)], ) @pytest.mark.parametrize("padding", ["SAME", "VALID"]) @pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 2)), ((3, 2), (1, 1))]) +@pytest.mark.parametrize("activation", ACTIVATIONS) def test_tflite_depthwise_conv2d( accel_type, ifm_shape, @@ -208,7 +225,7 @@ def test_tflite_depthwise_conv2d( padding, strides, dilation, - activation_function, + activation, ): np.random.seed(0) @@ -221,8 +238,8 @@ def depthwise_conv2d(x): op = tf.nn.depthwise_conv2d( x, weight, strides=tf_strides, padding=padding, dilations=dilation ) - if activation_function == "RELU": - op = tf.nn.relu(op) + if activation: + op = activation(op) return op infra.compare_tvm_with_tflite(depthwise_conv2d, [ifm_shape], accel_type) @@ -265,17 +282,18 @@ def depthwise_conv2d(x): @pytest.mark.parametrize("pooling_type", ["MAX", "AVG"]) @pytest.mark.parametrize("ifm_shape", [[1, 3, 4, 3], [1, 4, 5, 2]]) @pytest.mark.parametrize( - "pool_shape, strides, activation_function, padding", - [([1, 2], [1, 2], "NONE", "SAME"), ([2, 3], [2, 3], "RELU", "VALID")], + "pool_shape, strides, padding", + [([1, 2], [1, 2], "SAME"), 
([2, 3], [2, 3], "VALID")], ) +@pytest.mark.parametrize("activation", ACTIVATIONS) def test_ethosu_pooling( accel_type, ifm_shape, pooling_type, strides, pool_shape, - activation_function, padding, + activation, ): np.random.seed(0) @@ -285,8 +303,8 @@ def pooling(x): op = tf.nn.max_pool(x, pool_shape, strides, padding) elif pooling_type == "AVG": op = tf.nn.avg_pool(x, pool_shape, strides, padding) - if activation_function == "RELU": - op = tf.nn.relu(op) + if activation: + op = activation(op) return op infra.compare_tvm_with_tflite(pooling, [ifm_shape], accel_type) @@ -303,13 +321,13 @@ def pooling(x): ([1, 4, 4], [4, 1]), ], ) -@pytest.mark.parametrize("activation_function", ["NONE", "RELU"]) +@pytest.mark.parametrize("activation", ACTIVATIONS) def test_ethosu_binary_elementwise( accel_type, operator_type, ifm_shape, ifm2_shape, - activation_function, + activation, ): np.random.seed(0) @@ -325,8 +343,8 @@ def binary_elementwise(lhs, rhs): op = tf.math.minimum(lhs, rhs) elif operator_type == "MAX": op = tf.math.maximum(lhs, rhs) - if activation_function == "RELU": - op = tf.nn.relu(op) + if activation: + op = activation(op) return op infra.compare_tvm_with_tflite( @@ -1113,13 +1131,13 @@ def leaky_relu_func(x): @pytest.mark.parametrize("ifm_shape", [(1, 14), (1, 151)]) @pytest.mark.parametrize("ofm_channels", [32, 64]) @pytest.mark.parametrize("use_bias", [True, False]) -@pytest.mark.parametrize("activation_function", ["RELU", "NONE"]) +@pytest.mark.parametrize("activation", ACTIVATIONS) def test_tflite_fully_connected( accel_type, ifm_shape, ofm_channels, use_bias, - activation_function, + activation, ): np.random.seed(0) @@ -1134,8 +1152,8 @@ def fully_connected(x): x = tf.matmul(x, w) if use_bias: x = tf.nn.bias_add(x, bias) - if activation_function: - x = tf.nn.relu(x) + if activation: + x = activation(x) return x infra.compare_tvm_with_tflite( diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py 
index 9b4dd467ff9f..7641d9c4e6f0 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -899,6 +899,11 @@ def verify(ext_func): elif operator_type == "MIN": rewriter = legalize.MinRewriter() pattern_table = [ + ( + ethosu.MinParams.composite_name, + ethosu.minimum_clip_requantize_pattern(), + lambda pat: ethosu.MinParams(pat).is_valid(), + ), ( ethosu.MinParams.composite_name, ethosu.minimum_pattern(), @@ -908,6 +913,11 @@ def verify(ext_func): elif operator_type == "MAX": rewriter = legalize.MaxRewriter() pattern_table = [ + ( + ethosu.MaxParams.composite_name, + ethosu.maximum_clip_requantize_pattern(), + lambda pat: ethosu.MaxParams(pat).is_valid(), + ), ( ethosu.MaxParams.composite_name, ethosu.maximum_pattern(), From 2da2b0e4663decef6461bf1bda963cf8ab96acdd Mon Sep 17 00:00:00 2001 From: Alexey-Yazev <113356454+Alexey-Yazev@users.noreply.github.com> Date: Wed, 28 Dec 2022 18:28:49 +0400 Subject: [PATCH 2/5] add separate tests for relu6, relu_n1_to_1 --- .../contrib/test_ethosu/test_codegen.py | 113 +++++++++++------- 1 file changed, 67 insertions(+), 46 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 8d086742562f..b318426c08c8 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -37,16 +37,6 @@ ACCEL_TYPES = ["ethos-u55-256", "ethos-u55-128", "ethos-u55-64", "ethos-u55-32", "ethos-u65-256"] -def relu_n1_to_1(x): - """ - The specific pattern will be replaced into RELU_N1_TO_1 by tflite. 
- """ - return tf.math.maximum(-1.0, tf.math.minimum(x, 1.0)) - - -ACTIVATIONS = [None, tf.nn.relu, tf.nn.relu6, relu_n1_to_1] - - def is_u55_accel_type(accel_type): return "u55" in accel_type @@ -56,7 +46,7 @@ def is_u55_accel_type(accel_type): @pytest.mark.parametrize("kernel_shape", [(3, 2), (1, 3)]) @pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))]) @pytest.mark.parametrize("padding", ["SAME", "VALID"]) -@pytest.mark.parametrize("activation", ACTIVATIONS) +@pytest.mark.parametrize("activation", ["NONE", "RELU"]) def test_ethosu_conv2d_single( ifm_shape, kernel_shape, @@ -82,8 +72,8 @@ def conv2d(x): padding=padding, dilations=dilation, ) - if activation: - op = activation(op) + if activation == "RELU": + op = tf.nn.relu(op) return op infra.compare_tvm_with_tflite(conv2d, [ifm_shape], accel_type) @@ -124,7 +114,7 @@ def conv2d(x): @pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))]) @pytest.mark.parametrize("padding", ["SAME", "VALID"]) @pytest.mark.parametrize("accel_type", ACCEL_TYPES + ["ethos-u65-512"]) -@pytest.mark.parametrize("activation", ACTIVATIONS) +@pytest.mark.parametrize("activation", ["NONE", "RELU"]) def test_ethosu_conv2d_double( ifm_shape, kernel_shape, @@ -160,28 +150,22 @@ def conv2d_double(x): padding=padding, dilations=dilation, ) - if activation: - op2 = activation(op2) + if activation == "RELU": + op2 = tf.nn.relu(op2) return op2 infra.compare_tvm_with_tflite(conv2d_double, [ifm_shape], accel_type) @pytest.mark.parametrize("weight_min, weight_max", [(0.0, 1e-11), (-1e10, 1e10)]) -# relu6 and relu_n1_to_1 operations are excluded from activations since tflite results are different. -# In the tflite model, a rather large scale is generated, so in some cases in tflite result is -128 in ethosu 127. 
-@pytest.mark.parametrize("activation", [None, tf.nn.relu]) -def test_out_of_range_scaling( - weight_min, - weight_max, - activation, -): +def test_out_of_range_scaling(weight_min, weight_max): np.random.seed(0) ifm_shape = (1, 6, 6, 2) strides = (1, 1) kernel_shape = (1, 1) dilation = (1, 1) padding = "SAME" + activation = "RELU" accel_type = "ethos-u55-128" @tf.function @@ -202,8 +186,8 @@ def conv_invalid_scale(x): padding=padding, dilations=dilation, ) - if activation: - op = activation(op) + if activation == "RELU": + op = tf.nn.relu(op) return op infra.compare_tvm_with_tflite(conv_invalid_scale, [ifm_shape], accel_type) @@ -212,12 +196,11 @@ def conv_invalid_scale(x): @pytest.mark.parametrize("accel_type", ACCEL_TYPES) @pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3), (1, 23, 32, 7)]) @pytest.mark.parametrize( - "kernel_shape", - [(3, 3), (1, 2)], + "kernel_shape, activation_function", + [((3, 3), "RELU"), ((1, 2), "NONE")], ) @pytest.mark.parametrize("padding", ["SAME", "VALID"]) @pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 2)), ((3, 2), (1, 1))]) -@pytest.mark.parametrize("activation", ACTIVATIONS) def test_tflite_depthwise_conv2d( accel_type, ifm_shape, @@ -225,7 +208,7 @@ def test_tflite_depthwise_conv2d( padding, strides, dilation, - activation, + activation_function, ): np.random.seed(0) @@ -238,8 +221,8 @@ def depthwise_conv2d(x): op = tf.nn.depthwise_conv2d( x, weight, strides=tf_strides, padding=padding, dilations=dilation ) - if activation: - op = activation(op) + if activation_function == "RELU": + op = tf.nn.relu(op) return op infra.compare_tvm_with_tflite(depthwise_conv2d, [ifm_shape], accel_type) @@ -282,18 +265,17 @@ def depthwise_conv2d(x): @pytest.mark.parametrize("pooling_type", ["MAX", "AVG"]) @pytest.mark.parametrize("ifm_shape", [[1, 3, 4, 3], [1, 4, 5, 2]]) @pytest.mark.parametrize( - "pool_shape, strides, padding", - [([1, 2], [1, 2], "SAME"), ([2, 3], [2, 3], "VALID")], + "pool_shape, strides, activation_function, 
padding", + [([1, 2], [1, 2], "NONE", "SAME"), ([2, 3], [2, 3], "RELU", "VALID")], ) -@pytest.mark.parametrize("activation", ACTIVATIONS) def test_ethosu_pooling( accel_type, ifm_shape, pooling_type, strides, pool_shape, + activation_function, padding, - activation, ): np.random.seed(0) @@ -303,8 +285,8 @@ def pooling(x): op = tf.nn.max_pool(x, pool_shape, strides, padding) elif pooling_type == "AVG": op = tf.nn.avg_pool(x, pool_shape, strides, padding) - if activation: - op = activation(op) + if activation_function == "RELU": + op = tf.nn.relu(op) return op infra.compare_tvm_with_tflite(pooling, [ifm_shape], accel_type) @@ -321,13 +303,13 @@ def pooling(x): ([1, 4, 4], [4, 1]), ], ) -@pytest.mark.parametrize("activation", ACTIVATIONS) +@pytest.mark.parametrize("activation_function", ["NONE", "RELU"]) def test_ethosu_binary_elementwise( accel_type, operator_type, ifm_shape, ifm2_shape, - activation, + activation_function, ): np.random.seed(0) @@ -343,8 +325,8 @@ def binary_elementwise(lhs, rhs): op = tf.math.minimum(lhs, rhs) elif operator_type == "MAX": op = tf.math.maximum(lhs, rhs) - if activation: - op = activation(op) + if activation_function == "RELU": + op = tf.nn.relu(op) return op infra.compare_tvm_with_tflite( @@ -1127,17 +1109,56 @@ def leaky_relu_func(x): ) +def test_tflite_relu6(): + np.random.seed(0) + accel_type = "ethos-u55-128" + ifm_shape = (1, 12, 16, 8) + + @tf.function + def relu6(x): + return tf.nn.relu6(x) + + infra.compare_tvm_with_tflite( + relu6, + [ifm_shape], + accel_type, + enable_cascader=is_u55_accel_type(accel_type), + ranges=[(-1, 1)], + ) + + +def test_tflite_relu_n1_to_1(): + np.random.seed(0) + accel_type = "ethos-u55-128" + ifm_shape = (1, 12, 16, 8) + + @tf.function + def relu_n1_to_1(x): + """ + The specific pattern will be replaced into RELU_N1_TO_1 by tflite. 
+ """ + return tf.math.maximum(-1.0, tf.math.minimum(x, 1.0)) + + infra.compare_tvm_with_tflite( + relu_n1_to_1, + [ifm_shape], + accel_type, + enable_cascader=is_u55_accel_type(accel_type), + ranges=[(-1, 1)], + ) + + @pytest.mark.parametrize("accel_type", ACCEL_TYPES) @pytest.mark.parametrize("ifm_shape", [(1, 14), (1, 151)]) @pytest.mark.parametrize("ofm_channels", [32, 64]) @pytest.mark.parametrize("use_bias", [True, False]) -@pytest.mark.parametrize("activation", ACTIVATIONS) +@pytest.mark.parametrize("activation_function", ["RELU", "NONE"]) def test_tflite_fully_connected( accel_type, ifm_shape, ofm_channels, use_bias, - activation, + activation_function, ): np.random.seed(0) @@ -1152,8 +1173,8 @@ def fully_connected(x): x = tf.matmul(x, w) if use_bias: x = tf.nn.bias_add(x, bias) - if activation: - x = activation(x) + if activation_function: + x = tf.nn.relu(x) return x infra.compare_tvm_with_tflite( From f692260b2b8bdbb904ec22f10c212edb33957bdf Mon Sep 17 00:00:00 2001 From: Alexey-Yazev <113356454+Alexey-Yazev@users.noreply.github.com> Date: Fri, 30 Dec 2022 17:23:46 +0400 Subject: [PATCH 3/5] keep test max_relu_n1_to_1 Keep relu_n1_to_1 to test the case where the activation cannot be fused with the previous operation --- .../contrib/test_ethosu/test_codegen.py | 33 ++++--------------- 1 file changed, 7 insertions(+), 26 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index b318426c08c8..62372a05d39f 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -1109,42 +1109,23 @@ def leaky_relu_func(x): ) -def test_tflite_relu6(): - np.random.seed(0) - accel_type = "ethos-u55-128" - ifm_shape = (1, 12, 16, 8) - - @tf.function - def relu6(x): - return tf.nn.relu6(x) - - infra.compare_tvm_with_tflite( - relu6, - [ifm_shape], - accel_type, - enable_cascader=is_u55_accel_type(accel_type), - ranges=[(-1, 1)], - ) - - def
test_tflite_relu_n1_to_1(): np.random.seed(0) accel_type = "ethos-u55-128" ifm_shape = (1, 12, 16, 8) @tf.function - def relu_n1_to_1(x): - """ - The specific pattern will be replaced into RELU_N1_TO_1 by tflite. - """ - return tf.math.maximum(-1.0, tf.math.minimum(x, 1.0)) + def max_relu_n1_to_1(lhs, rhs): + op = tf.math.maximum(lhs, rhs) + # The specific pattern will be replaced into RELU_N1_TO_1 by tflite. + return tf.math.maximum(-1.0, tf.math.minimum(op, 1.0)) infra.compare_tvm_with_tflite( - relu_n1_to_1, - [ifm_shape], + max_relu_n1_to_1, + [ifm_shape, ifm_shape], accel_type, enable_cascader=is_u55_accel_type(accel_type), - ranges=[(-1, 1)], + ranges=[(-1, 1), (0, 2)], ) From 87b449d30342b932c4223370c061910823e748ca Mon Sep 17 00:00:00 2001 From: Alexey-Yazev <113356454+Alexey-Yazev@users.noreply.github.com> Date: Wed, 11 Jan 2023 18:13:37 +0400 Subject: [PATCH 4/5] add relu6, relu_n1_to_1 test cases for Ethos-U, remove changes for BinaryElementwise operation --- python/tvm/relay/op/contrib/ethosu.py | 62 +++++------------ .../contrib/test_ethosu/test_codegen.py | 67 ++++++++++++++++--- .../contrib/test_ethosu/test_legalize.py | 10 --- 3 files changed, 74 insertions(+), 65 deletions(-) diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index e1ec3172428b..c0f8e5e9708e 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -688,13 +688,15 @@ def __init__(self, func_body: Call, operator_type: str, is_quantized_operation: clip = None requantize = None - if str(current_call.op) == "clip": - clip = current_call - current_call = clip.args[0] - elif str(current_call.op) == "qnn.requantize": - requantize = current_call - clip = current_call.args[0] - current_call = clip.args[0] + if is_quantized_operation: + if str(current_call.op) == "clip": + clip = current_call + current_call = clip.args[0] + else: + if str(current_call.op) == "qnn.requantize": + requantize = current_call + clip = 
current_call.args[0] + current_call = clip.args[0] binary_op = current_call layout = "NHWC" @@ -927,9 +929,6 @@ def is_valid(self): [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8] ): return False - # MIN with different scales is not supported on NPU. - if self.ifm.q_params.scale_f32 != self.ofm.q_params.scale_f32: - return False return True @@ -939,19 +938,10 @@ def minimum_pattern() -> tvm.relay.dataflow_pattern.DFPattern: """ minimum = is_op("minimum")(wildcard(), wildcard()) optional_min_clip = is_op("clip")(minimum) - return minimum | optional_min_clip - - -def minimum_clip_requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern: - """ - This function creates the pattern for minimum with fused RELU activation. - """ - pattern = is_op("minimum")(wildcard(), wildcard()) - pattern = is_op("clip")(pattern) - pattern = is_op("qnn.requantize")( - pattern, is_constant(), is_constant(), is_constant(), is_constant() + optional_min_clip = is_op("qnn.requantize")( + optional_min_clip, is_constant(), is_constant(), is_constant(), is_constant() ) - return pattern + return minimum | optional_min_clip class MaxParams(BinaryElementwiseParams): @@ -977,9 +967,6 @@ def is_valid(self): [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8] ): return False - # MAX with different scales is not supported on NPU. - if self.ifm.q_params.scale_f32 != self.ofm.q_params.scale_f32: - return False return True @@ -989,19 +976,10 @@ def maximum_pattern() -> tvm.relay.dataflow_pattern.DFPattern: """ maximum = is_op("maximum")(wildcard(), wildcard()) optional_max_clip = is_op("clip")(maximum) - return maximum | optional_max_clip - - -def maximum_clip_requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern: - """ - This function creates the pattern for maximum with fused RELU activation. 
- """ - pattern = is_op("maximum")(wildcard(), wildcard()) - pattern = is_op("clip")(pattern) - pattern = is_op("qnn.requantize")( - pattern, is_constant(), is_constant(), is_constant(), is_constant() + optional_max_clip = is_op("qnn.requantize")( + optional_max_clip, is_constant(), is_constant(), is_constant(), is_constant() ) - return pattern + return maximum | optional_max_clip class ShlParams(BinaryElementwiseParams): @@ -1842,21 +1820,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal qnn_mul_pattern(), lambda pat: MulParams(pat).is_valid(), ), - ( - MinParams.composite_name, - minimum_clip_requantize_pattern(), - lambda pat: MinParams(pat).is_valid(), - ), ( MinParams.composite_name, minimum_pattern(), lambda pat: MinParams(pat).is_valid(), ), - ( - MaxParams.composite_name, - maximum_clip_requantize_pattern(), - lambda pat: MaxParams(pat).is_valid(), - ), ( MaxParams.composite_name, maximum_pattern(), diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 62372a05d39f..afd5486f360d 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -1109,23 +1109,74 @@ def leaky_relu_func(x): ) +# conv2d + relu_n1_to_1 is used because separate activation is not offloaded to NPU. 
def test_tflite_relu_n1_to_1(): np.random.seed(0) - accel_type = "ethos-u55-128" - ifm_shape = (1, 12, 16, 8) + accel_type = "ethos-u55-256" + ifm_shape = (1, 55, 34, 3) + kernel_shape = (3, 2) + strides = (1, 1) + padding = (1, 0, 1, 1) @tf.function - def max_relu_n1_to_1(lhs, rhs): - op = tf.math.maximum(lhs, rhs) + def conv2d_relu_n1_to_1(x): + tf_strides = [1, strides[0], strides[1], 1] + op = tf.pad( + x, + [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]], + "CONSTANT", + ) + weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3] + weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + op = tf.nn.conv2d( + op, + weight, + strides=tf_strides, + padding="VALID", + ) # The specific pattern will be replaced into RELU_N1_TO_1 by tflite. return tf.math.maximum(-1.0, tf.math.minimum(op, 1.0)) infra.compare_tvm_with_tflite( - max_relu_n1_to_1, - [ifm_shape, ifm_shape], + conv2d_relu_n1_to_1, + [ifm_shape], accel_type, - enable_cascader=is_u55_accel_type(accel_type), - ranges=[(-1, 1), (0, 2)], + enable_cascader=True, + ) + + +# conv2d + relu6 is used because separate activation is not offloaded to NPU. 
+def test_tflite_relu6(): + np.random.seed(0) + accel_type = "ethos-u55-256" + ifm_shape = (1, 55, 34, 3) + kernel_shape = (3, 2) + strides = (1, 1) + padding = (0, 0, 1, 1) + + @tf.function + def conv2d_relu6(x): + tf_strides = [1, strides[0], strides[1], 1] + op = tf.pad( + x, + [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]], + "CONSTANT", + ) + weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3] + weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + op = tf.nn.conv2d( + op, + weight, + strides=tf_strides, + padding="VALID", + ) + return tf.nn.relu6(op) + + infra.compare_tvm_with_tflite( + conv2d_relu6, + [ifm_shape], + accel_type, + enable_cascader=True, ) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 7641d9c4e6f0..9b4dd467ff9f 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -899,11 +899,6 @@ def verify(ext_func): elif operator_type == "MIN": rewriter = legalize.MinRewriter() pattern_table = [ - ( - ethosu.MinParams.composite_name, - ethosu.minimum_clip_requantize_pattern(), - lambda pat: ethosu.MinParams(pat).is_valid(), - ), ( ethosu.MinParams.composite_name, ethosu.minimum_pattern(), @@ -913,11 +908,6 @@ def verify(ext_func): elif operator_type == "MAX": rewriter = legalize.MaxRewriter() pattern_table = [ - ( - ethosu.MaxParams.composite_name, - ethosu.maximum_clip_requantize_pattern(), - lambda pat: ethosu.MaxParams(pat).is_valid(), - ), ( ethosu.MaxParams.composite_name, ethosu.maximum_pattern(), From 1332b95f42647169f83cbad6ba8a2bba47a253ef Mon Sep 17 00:00:00 2001 From: Alexey-Yazev <113356454+Alexey-Yazev@users.noreply.github.com> Date: Thu, 12 Jan 2023 11:56:15 +0400 Subject: [PATCH 5/5] remove pad operation from relu6, relu_n1_to_1 test cases --- tests/python/contrib/test_ethosu/test_codegen.py | 16 ++-------------- 1 file changed, 2 
insertions(+), 14 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index afd5486f360d..c164b177c21c 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -1116,20 +1116,14 @@ def test_tflite_relu_n1_to_1(): ifm_shape = (1, 55, 34, 3) kernel_shape = (3, 2) strides = (1, 1) - padding = (1, 0, 1, 1) @tf.function def conv2d_relu_n1_to_1(x): tf_strides = [1, strides[0], strides[1], 1] - op = tf.pad( - x, - [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]], - "CONSTANT", - ) weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3] weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) op = tf.nn.conv2d( - op, + x, weight, strides=tf_strides, padding="VALID", @@ -1152,20 +1146,14 @@ def test_tflite_relu6(): ifm_shape = (1, 55, 34, 3) kernel_shape = (3, 2) strides = (1, 1) - padding = (0, 0, 1, 1) @tf.function def conv2d_relu6(x): tf_strides = [1, strides[0], strides[1], 1] - op = tf.pad( - x, - [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]], - "CONSTANT", - ) weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3] weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) op = tf.nn.conv2d( - op, + x, weight, strides=tf_strides, padding="VALID",