From 61795d9bb2fdf3bd2c614542cd4bfac41dc987ba Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton@arm.com>
Date: Wed, 7 Feb 2024 22:18:47 +0000
Subject: [PATCH] fix lint and division by zero

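Remove an unused numpy import flagged by pylint and wrap the long
pytest skip reasons that exceeded the line-length limit. Also guard
the vectorization split factor in the arm_cpu injective schedule
against sub-byte dtypes (e.g. "bool"), for which
tvm.DataType(dtype).bits // 8 evaluates to 0 and previously caused a
ZeroDivisionError.
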
Change-Id: Ia79d879399ad7f2d098fd4a0af5c29a89565133e
---
 python/tvm/topi/arm_cpu/injective.py         |  4 ++--
 tests/python/frontend/tflite/test_forward.py | 20 ++++++++++++++++----
 2 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/python/tvm/topi/arm_cpu/injective.py b/python/tvm/topi/arm_cpu/injective.py
index bac49b1824349..fbc071092503d 100644
--- a/python/tvm/topi/arm_cpu/injective.py
+++ b/python/tvm/topi/arm_cpu/injective.py
@@ -16,7 +16,6 @@
 # under the License.
 # pylint: disable=invalid-name, unused-variable
 """Schedule for pooling operators"""
-import numpy as np
 import tvm
 from tvm import te
 from ..utils import is_empty_shape
@@ -69,7 +68,8 @@ def schedule_injective(outs):
     if list(s[x].op.axis):
         # do not vectorize for broadcast
         dtype = "uint16" if x.dtype == "bfloat16" else x.dtype
-        (io, ii) = s[x].split(list(s[x].op.axis)[-1], 16 // (tvm.DataType(dtype).bits // 8))
+        itemsize = max(1, tvm.DataType(dtype).bits // 8)
+        (io, ii) = s[x].split(list(s[x].op.axis)[-1], 16 // itemsize)
         s[x].vectorize(ii)
     tvm.te.schedule.AutoInlineInjective(s)
 
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 2e99f7c97dc23..26c8bc31af1c6 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -1102,7 +1102,10 @@ def test_forward_quantized_convolution():
 
         if platform.machine() == "aarch64":
             pytest.skip(
-                reason="Grouped convolution type inference error for `arm_cpu`. See https://github.com/apache/tvm/issues/16532"
+                reason=(
+                    "Grouped convolution type inference error for `arm_cpu`. "
+                    "See https://github.com/apache/tvm/issues/16532"
+                )
             )
 
         _test_tflite2_quantized_convolution(
@@ -1130,7 +1133,10 @@ def test_forward_quantized_depthwise_convolution():
 
     if platform.machine() == "aarch64":
         pytest.skip(
-            reason="Tensor intrinsic data type mismatch error. See https://github.com/apache/tvm/issues/16533"
+            reason=(
+                "Tensor intrinsic data type mismatch error. "
+                "See https://github.com/apache/tvm/issues/16533"
+            )
         )
 
     _test_tflite2_quantized_depthwise_convolution(
@@ -5207,7 +5213,10 @@ def test_forward_tflite_float16():
 
 @pytest.mark.skipif(
     platform.machine() == "aarch64",
-    reason="Fails during leagalization due to int16 datatype. See https://github.com/apache/tvm/issues/16535",
+    reason=(
+        "Fails during legalization due to int16 datatype. "
+        "See https://github.com/apache/tvm/issues/16535"
+    ),
 )
 def test_forward_mobilenet_int16():
     """Test int16 quantized model"""
@@ -5253,7 +5262,10 @@ def representative_dataset():
 
 @pytest.mark.skipif(
     platform.machine() == "aarch64",
-    reason="Fails during leagalization due to int16 datatype. See https://github.com/apache/tvm/issues/16535",
+    reason=(
+        "Fails during legalization due to int16 datatype. "
+        "See https://github.com/apache/tvm/issues/16535"
+    ),
 )
 def test_forward_ds_cnn_int16():
     """Test DS_CNN int16 quantized model"""