Skip to content

Commit

Permalink
fix lint and division by zero
Browse files Browse the repository at this point in the history
Change-Id: Ia79d879399ad7f2d098fd4a0af5c29a89565133e
  • Loading branch information
lhutton1 committed Feb 8, 2024
1 parent 613e662 commit 61795d9
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
4 changes: 2 additions & 2 deletions python/tvm/topi/arm_cpu/injective.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
# pylint: disable=invalid-name, unused-variable
"""Schedule for pooling operators"""
import numpy as np
import tvm
from tvm import te
from ..utils import is_empty_shape
Expand Down Expand Up @@ -69,7 +68,8 @@ def schedule_injective(outs):
if list(s[x].op.axis):
# do not vectorize for broadcast
dtype = "uint16" if x.dtype == "bfloat16" else x.dtype
(io, ii) = s[x].split(list(s[x].op.axis)[-1], 16 // (tvm.DataType(dtype).bits // 8))
itemsize = max(1, tvm.DataType(dtype).bits // 8)
(io, ii) = s[x].split(list(s[x].op.axis)[-1], 16 // itemsize)
s[x].vectorize(ii)
tvm.te.schedule.AutoInlineInjective(s)

Expand Down
20 changes: 16 additions & 4 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,7 +1102,10 @@ def test_forward_quantized_convolution():

if platform.machine() == "aarch64":
pytest.skip(
reason="Grouped convolution type inference error for `arm_cpu`. See https://github.com/apache/tvm/issues/16532"
reason=(
"Grouped convolution type inference error for `arm_cpu`. "
"See https://github.com/apache/tvm/issues/16532"
)
)

_test_tflite2_quantized_convolution(
Expand Down Expand Up @@ -1130,7 +1133,10 @@ def test_forward_quantized_depthwise_convolution():

if platform.machine() == "aarch64":
pytest.skip(
reason="Tensor intrinsic data type mismatch error. See https://github.com/apache/tvm/issues/16533"
reason=(
"Tensor intrinsic data type mismatch error. "
"See https://github.com/apache/tvm/issues/16533"
)
)

_test_tflite2_quantized_depthwise_convolution(
Expand Down Expand Up @@ -5207,7 +5213,10 @@ def test_forward_tflite_float16():

@pytest.mark.skipif(
platform.machine() == "aarch64",
reason="Fails during leagalization due to int16 datatype. See https://github.com/apache/tvm/issues/16535",
reason=(
"Fails during leagalization due to int16 datatype. "
"See https://github.com/apache/tvm/issues/16535",
),
)
def test_forward_mobilenet_int16():
"""Test int16 quantized model"""
Expand Down Expand Up @@ -5253,7 +5262,10 @@ def representative_dataset():

@pytest.mark.skipif(
platform.machine() == "aarch64",
reason="Fails during leagalization due to int16 datatype. See https://github.com/apache/tvm/issues/16535",
reason=(
"Fails during leagalization due to int16 datatype. "
"See https://github.com/apache/tvm/issues/16535",
),
)
def test_forward_ds_cnn_int16():
"""Test DS_CNN int16 quantized model"""
Expand Down

0 comments on commit 61795d9

Please sign in to comment.