diff --git a/python/tvm/topi/arm_cpu/injective.py b/python/tvm/topi/arm_cpu/injective.py index bac49b1824349..fbc071092503d 100644 --- a/python/tvm/topi/arm_cpu/injective.py +++ b/python/tvm/topi/arm_cpu/injective.py @@ -16,7 +16,6 @@ # under the License. # pylint: disable=invalid-name, unused-variable """Schedule for pooling operators""" -import numpy as np import tvm from tvm import te from ..utils import is_empty_shape @@ -69,7 +68,8 @@ def schedule_injective(outs): if list(s[x].op.axis): # do not vectorize for broadcast dtype = "uint16" if x.dtype == "bfloat16" else x.dtype - (io, ii) = s[x].split(list(s[x].op.axis)[-1], 16 // (tvm.DataType(dtype).bits // 8)) + itemsize = max(1, tvm.DataType(dtype).bits // 8) + (io, ii) = s[x].split(list(s[x].op.axis)[-1], 16 // itemsize) s[x].vectorize(ii) tvm.te.schedule.AutoInlineInjective(s) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 2e99f7c97dc23..26c8bc31af1c6 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1102,7 +1102,10 @@ def test_forward_quantized_convolution(): if platform.machine() == "aarch64": pytest.skip( - reason="Grouped convolution type inference error for `arm_cpu`. See https://github.com/apache/tvm/issues/16532" + reason=( + "Grouped convolution type inference error for `arm_cpu`. " + "See https://github.com/apache/tvm/issues/16532" + ) ) _test_tflite2_quantized_convolution( @@ -1130,7 +1133,10 @@ def test_forward_quantized_depthwise_convolution(): if platform.machine() == "aarch64": pytest.skip( - reason="Tensor intrinsic data type mismatch error. See https://github.com/apache/tvm/issues/16533" + reason=( + "Tensor intrinsic data type mismatch error. " + "See https://github.com/apache/tvm/issues/16533" + ) ) _test_tflite2_quantized_depthwise_convolution( @@ -5207,7 +5213,10 @@ def test_forward_tflite_float16(): @pytest.mark.skipif( platform.machine() == "aarch64", - reason="Fails during leagalization due to int16 datatype. See https://github.com/apache/tvm/issues/16535", + reason=( + "Fails during leagalization due to int16 datatype. " + "See https://github.com/apache/tvm/issues/16535", + ), ) def test_forward_mobilenet_int16(): """Test int16 quantized model""" @@ -5253,7 +5262,10 @@ def representative_dataset(): @pytest.mark.skipif( platform.machine() == "aarch64", - reason="Fails during leagalization due to int16 datatype. See https://github.com/apache/tvm/issues/16535", + reason=( + "Fails during leagalization due to int16 datatype. " + "See https://github.com/apache/tvm/issues/16535", + ), ) def test_forward_ds_cnn_int16(): """Test DS_CNN int16 quantized model"""