[microNPU] Add relu6 relu_n1_to_1 test cases for Ethos-U
Tests are extended with cases using the relu6 and relu_n1_to_1 activations.
Min and max operations are not fused with requantize when the scales differ, as this is not supported on the NPU.
Aleksei-grovety committed Dec 19, 2022
1 parent 6161a8d commit 2a4b651
Showing 3 changed files with 103 additions and 43 deletions.
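For context, the sketch below is illustrative only and not part of this commit (helper names are made up; TensorFlow is assumed to be available). It shows the kind of graph the extended tests exercise: a binary elementwise operation followed by relu6 or by the max/min pattern that TFLite treats as relu_n1_to_1. After TFLite quantization such activations roughly become a clip, optionally followed by a requantize when the output scale differs from the input scale, which is the case the new patterns and validity checks below handle.

# Illustrative only -- not part of the commit. Assumes TensorFlow is installed.
import tensorflow as tf

@tf.function
def minimum_with_relu6(lhs, rhs):
    op = tf.math.minimum(lhs, rhs)
    return tf.nn.relu6(op)

@tf.function
def minimum_with_relu_n1_to_1(lhs, rhs):
    op = tf.math.minimum(lhs, rhs)
    # This max/min pair is what the TFLite converter recognises as RELU_N1_TO_1.
    return tf.math.maximum(-1.0, tf.math.minimum(op, 1.0))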
62 changes: 47 additions & 15 deletions python/tvm/relay/op/contrib/ethosu.py
@@ -688,15 +688,13 @@ def __init__(self, func_body: Call, operator_type: str, is_quantized_operation:
        clip = None
        requantize = None

-        if is_quantized_operation:
-            if str(current_call.op) == "clip":
-                clip = current_call
-                current_call = clip.args[0]
-        else:
-            if str(current_call.op) == "qnn.requantize":
-                requantize = current_call
-                clip = current_call.args[0]
-                current_call = clip.args[0]
+        if str(current_call.op) == "clip":
+            clip = current_call
+            current_call = clip.args[0]
+        elif str(current_call.op) == "qnn.requantize":
+            requantize = current_call
+            clip = current_call.args[0]
+            current_call = clip.args[0]
        binary_op = current_call

        layout = "NHWC"
@@ -929,6 +927,9 @@ def is_valid(self):
            [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8]
        ):
            return False
+        # MIN with different scales is not supported on NPU.
+        if self.ifm.q_params.scale_f32 != self.ofm.q_params.scale_f32:
+            return False
        return True


@@ -938,12 +939,21 @@ def minimum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
    """
    minimum = is_op("minimum")(wildcard(), wildcard())
    optional_min_clip = is_op("clip")(minimum)
-    optional_min_clip = is_op("qnn.requantize")(
-        optional_min_clip, is_constant(), is_constant(), is_constant(), is_constant()
-    )
    return minimum | optional_min_clip


+def minimum_clip_requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for minimum with fused RELU activation.
+    """
+    pattern = is_op("minimum")(wildcard(), wildcard())
+    pattern = is_op("clip")(pattern)
+    pattern = is_op("qnn.requantize")(
+        pattern, is_constant(), is_constant(), is_constant(), is_constant()
+    )
+    return pattern
+
+
class MaxParams(BinaryElementwiseParams):
    """
    This class will parse a call to a ethosu.binary_elementwise Max composite function
@@ -967,6 +977,9 @@ def is_valid(self):
            [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8]
        ):
            return False
+        # MAX with different scales is not supported on NPU.
+        if self.ifm.q_params.scale_f32 != self.ofm.q_params.scale_f32:
+            return False
        return True


@@ -976,12 +989,21 @@ def maximum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
    """
    maximum = is_op("maximum")(wildcard(), wildcard())
    optional_max_clip = is_op("clip")(maximum)
-    optional_max_clip = is_op("qnn.requantize")(
-        optional_max_clip, is_constant(), is_constant(), is_constant(), is_constant()
-    )
    return maximum | optional_max_clip


+def maximum_clip_requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for maximum with fused RELU activation.
+    """
+    pattern = is_op("maximum")(wildcard(), wildcard())
+    pattern = is_op("clip")(pattern)
+    pattern = is_op("qnn.requantize")(
+        pattern, is_constant(), is_constant(), is_constant(), is_constant()
+    )
+    return pattern
+
+
class ShlParams(BinaryElementwiseParams):
    """
    This class will parse a call to a ethosu.binary_elementwise Shl composite function
@@ -1820,11 +1842,21 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
            qnn_mul_pattern(),
            lambda pat: MulParams(pat).is_valid(),
        ),
+        (
+            MinParams.composite_name,
+            minimum_clip_requantize_pattern(),
+            lambda pat: MinParams(pat).is_valid(),
+        ),
        (
            MinParams.composite_name,
            minimum_pattern(),
            lambda pat: MinParams(pat).is_valid(),
        ),
+        (
+            MaxParams.composite_name,
+            maximum_clip_requantize_pattern(),
+            lambda pat: MaxParams(pat).is_valid(),
+        ),
        (
            MaxParams.composite_name,
            maximum_pattern(),
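For illustration, the following sketch is not part of the commit (shapes and quantization parameters are invented; TVM's Relay API is assumed). It builds a Relay expression of the form the new minimum_clip_requantize_pattern is meant to match: minimum followed by clip and qnn.requantize. MinParams.is_valid then rejects the match when the input and output scales differ, so the minimum and the requantize are not fused into a single NPU operation.

# Illustrative only -- not part of the commit. Assumes TVM with the Relay API.
from tvm import relay

x = relay.var("x", shape=(1, 4, 4, 8), dtype="int8")
y = relay.var("y", shape=(1, 4, 4, 8), dtype="int8")
expr = relay.minimum(x, y)
expr = relay.clip(expr, a_min=-128.0, a_max=127.0)
expr = relay.qnn.op.requantize(
    expr,
    input_scale=relay.const(0.0235, "float32"),
    input_zero_point=relay.const(-128, "int32"),
    output_scale=relay.const(0.0078, "float32"),  # differing scale: is_valid() rejects the fused match
    output_zero_point=relay.const(-128, "int32"),
    out_dtype="int8",
)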
74 changes: 46 additions & 28 deletions tests/python/contrib/test_ethosu/test_codegen.py
@@ -37,6 +37,16 @@
ACCEL_TYPES = ["ethos-u55-256", "ethos-u55-128", "ethos-u55-64", "ethos-u55-32", "ethos-u65-256"]


+def relu_n1_to_1(x):
+    """
+    This specific pattern is replaced with RELU_N1_TO_1 by the TFLite converter.
+    """
+    return tf.math.maximum(-1.0, tf.math.minimum(x, 1.0))
+
+
+ACTIVATIONS = [None, tf.nn.relu, tf.nn.relu6, relu_n1_to_1]
+
+
def is_u55_accel_type(accel_type):
    return "u55" in accel_type

@@ -46,7 +56,7 @@ def is_u55_accel_type(accel_type):
@pytest.mark.parametrize("kernel_shape", [(3, 2), (1, 3)])
@pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))])
@pytest.mark.parametrize("padding", ["SAME", "VALID"])
-@pytest.mark.parametrize("activation", ["NONE", "RELU"])
+@pytest.mark.parametrize("activation", ACTIVATIONS)
def test_ethosu_conv2d_single(
    ifm_shape,
    kernel_shape,
@@ -72,8 +82,8 @@ def conv2d(x):
            padding=padding,
            dilations=dilation,
        )
-        if activation == "RELU":
-            op = tf.nn.relu(op)
+        if activation:
+            op = activation(op)
        return op

    infra.compare_tvm_with_tflite(conv2d, [ifm_shape], accel_type)
@@ -114,7 +124,7 @@ def conv2d(x):
@pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))])
@pytest.mark.parametrize("padding", ["SAME", "VALID"])
@pytest.mark.parametrize("accel_type", ACCEL_TYPES + ["ethos-u65-512"])
-@pytest.mark.parametrize("activation", ["NONE", "RELU"])
+@pytest.mark.parametrize("activation", ACTIVATIONS)
def test_ethosu_conv2d_double(
    ifm_shape,
    kernel_shape,
@@ -150,22 +160,28 @@ def conv2d_double(x):
            padding=padding,
            dilations=dilation,
        )
-        if activation == "RELU":
-            op2 = tf.nn.relu(op2)
+        if activation:
+            op2 = activation(op2)
        return op2

    infra.compare_tvm_with_tflite(conv2d_double, [ifm_shape], accel_type)


@pytest.mark.parametrize("weight_min, weight_max", [(0.0, 1e-11), (-1e10, 1e10)])
def test_out_of_range_scaling(weight_min, weight_max):
# relu6 and relu_n1_to_1 operations are excluded from activations since tflite results are different.
# In the tflite model, a rather large scale is generated, so in some cases in tflite result is -128 in ethosu 127.
@pytest.mark.parametrize("activation", [None, tf.nn.relu])
def test_out_of_range_scaling(
weight_min,
weight_max,
activation,
):
np.random.seed(0)
ifm_shape = (1, 6, 6, 2)
strides = (1, 1)
kernel_shape = (1, 1)
dilation = (1, 1)
padding = "SAME"
activation = "RELU"
accel_type = "ethos-u55-128"

@tf.function
Expand All @@ -186,8 +202,8 @@ def conv_invalid_scale(x):
padding=padding,
dilations=dilation,
)
if activation == "RELU":
op = tf.nn.relu(op)
if activation:
op = activation(op)
return op

infra.compare_tvm_with_tflite(conv_invalid_scale, [ifm_shape], accel_type)
@@ -196,19 +212,20 @@ def conv_invalid_scale(x):
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3), (1, 23, 32, 7)])
@pytest.mark.parametrize(
-    "kernel_shape, activation_function",
-    [((3, 3), "RELU"), ((1, 2), "NONE")],
+    "kernel_shape",
+    [(3, 3), (1, 2)],
)
@pytest.mark.parametrize("padding", ["SAME", "VALID"])
@pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 2)), ((3, 2), (1, 1))])
+@pytest.mark.parametrize("activation", ACTIVATIONS)
def test_tflite_depthwise_conv2d(
    accel_type,
    ifm_shape,
    kernel_shape,
    padding,
    strides,
    dilation,
-    activation_function,
+    activation,
):
    np.random.seed(0)

@@ -221,8 +238,8 @@ def depthwise_conv2d(x):
        op = tf.nn.depthwise_conv2d(
            x, weight, strides=tf_strides, padding=padding, dilations=dilation
        )
-        if activation_function == "RELU":
-            op = tf.nn.relu(op)
+        if activation:
+            op = activation(op)
        return op

    infra.compare_tvm_with_tflite(depthwise_conv2d, [ifm_shape], accel_type)
@@ -265,17 +282,18 @@ def depthwise_conv2d(x):
@pytest.mark.parametrize("pooling_type", ["MAX", "AVG"])
@pytest.mark.parametrize("ifm_shape", [[1, 3, 4, 3], [1, 4, 5, 2]])
@pytest.mark.parametrize(
-    "pool_shape, strides, activation_function, padding",
-    [([1, 2], [1, 2], "NONE", "SAME"), ([2, 3], [2, 3], "RELU", "VALID")],
+    "pool_shape, strides, padding",
+    [([1, 2], [1, 2], "SAME"), ([2, 3], [2, 3], "VALID")],
)
+@pytest.mark.parametrize("activation", ACTIVATIONS)
def test_ethosu_pooling(
    accel_type,
    ifm_shape,
    pooling_type,
    strides,
    pool_shape,
-    activation_function,
    padding,
+    activation,
):
    np.random.seed(0)

@@ -285,8 +303,8 @@ def pooling(x):
            op = tf.nn.max_pool(x, pool_shape, strides, padding)
        elif pooling_type == "AVG":
            op = tf.nn.avg_pool(x, pool_shape, strides, padding)
-        if activation_function == "RELU":
-            op = tf.nn.relu(op)
+        if activation:
+            op = activation(op)
        return op

    infra.compare_tvm_with_tflite(pooling, [ifm_shape], accel_type)
@@ -303,13 +321,13 @@ def pooling(x):
        ([1, 4, 4], [4, 1]),
    ],
)
-@pytest.mark.parametrize("activation_function", ["NONE", "RELU"])
+@pytest.mark.parametrize("activation", ACTIVATIONS)
def test_ethosu_binary_elementwise(
    accel_type,
    operator_type,
    ifm_shape,
    ifm2_shape,
-    activation_function,
+    activation,
):
    np.random.seed(0)

@@ -325,8 +343,8 @@ def binary_elementwise(lhs, rhs):
            op = tf.math.minimum(lhs, rhs)
        elif operator_type == "MAX":
            op = tf.math.maximum(lhs, rhs)
-        if activation_function == "RELU":
-            op = tf.nn.relu(op)
+        if activation:
+            op = activation(op)
        return op

    infra.compare_tvm_with_tflite(
@@ -1113,13 +1131,13 @@ def leaky_relu_func(x):
@pytest.mark.parametrize("ifm_shape", [(1, 14), (1, 151)])
@pytest.mark.parametrize("ofm_channels", [32, 64])
@pytest.mark.parametrize("use_bias", [True, False])
-@pytest.mark.parametrize("activation_function", ["RELU", "NONE"])
+@pytest.mark.parametrize("activation", ACTIVATIONS)
def test_tflite_fully_connected(
    accel_type,
    ifm_shape,
    ofm_channels,
    use_bias,
-    activation_function,
+    activation,
):
    np.random.seed(0)

@@ -1134,8 +1152,8 @@ def fully_connected(x):
        x = tf.matmul(x, w)
        if use_bias:
            x = tf.nn.bias_add(x, bias)
-        if activation_function:
-            x = tf.nn.relu(x)
+        if activation:
+            x = activation(x)
        return x

    infra.compare_tvm_with_tflite(
10 changes: 10 additions & 0 deletions tests/python/contrib/test_ethosu/test_legalize.py
@@ -899,6 +899,11 @@ def verify(ext_func):
    elif operator_type == "MIN":
        rewriter = legalize.MinRewriter()
        pattern_table = [
+            (
+                ethosu.MinParams.composite_name,
+                ethosu.minimum_clip_requantize_pattern(),
+                lambda pat: ethosu.MinParams(pat).is_valid(),
+            ),
            (
                ethosu.MinParams.composite_name,
                ethosu.minimum_pattern(),
@@ -908,6 +913,11 @@ def verify(ext_func):
    elif operator_type == "MAX":
        rewriter = legalize.MaxRewriter()
        pattern_table = [
+            (
+                ethosu.MaxParams.composite_name,
+                ethosu.maximum_clip_requantize_pattern(),
+                lambda pat: ethosu.MaxParams(pat).is_valid(),
+            ),
            (
                ethosu.MaxParams.composite_name,
                ethosu.maximum_pattern(),
