diff --git a/topi/python/topi/arm_cpu/__init__.py b/topi/python/topi/arm_cpu/__init__.py index 8d78f67ac0b6..3e888de55fec 100644 --- a/topi/python/topi/arm_cpu/__init__.py +++ b/topi/python/topi/arm_cpu/__init__.py @@ -4,3 +4,4 @@ from . import depthwise_conv2d from . import conv2d_transpose from . import bitserial_conv2d +from . import injective diff --git a/topi/python/topi/arm_cpu/injective.py b/topi/python/topi/arm_cpu/injective.py new file mode 100755 index 000000000000..09ea86cea3a3 --- /dev/null +++ b/topi/python/topi/arm_cpu/injective.py @@ -0,0 +1,37 @@ +# pylint: disable=invalid-name, unused-variable +"""Schedule for pooling operators""" +import tvm +from .. import generic + +@generic.schedule_injective.register(["arm_cpu"]) +def schedule_injective(outs): + """ARM CPU schedule for injective op. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of injective in the format + of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs + s = tvm.create_schedule([x.op for x in outs]) + x = outs[0] + if list(s[x].op.axis): + # do not vectorize for broadcast + (io, ii) = s[x].split(list(s[x].op.axis)[-1], 8) + s[x].vectorize(ii) + tvm.schedule.AutoInlineInjective(s) + if len(s[x].op.axis) >= 4: + fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2]) + s[x].parallel(fused) + elif len(s[x].op.axis) >= 3: + fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1]) + s[x].parallel(fused) + elif len(s[x].op.axis) >= 2: + s[x].parallel(s[x].op.axis[0]) + return s diff --git a/topi/tests/python/test_topi_resize.py b/topi/tests/python/test_topi_resize.py index 6926a3a2a73c..80966b15ddbe 100644 --- a/topi/tests/python/test_topi_resize.py +++ b/topi/tests/python/test_topi_resize.py @@ -5,6 +5,8 @@ import topi.testing import math +from common import get_all_backend + def verify_bilinear_scale(batch, in_channel, in_height, in_width, out_height, out_width, layout='NCHW', align_corners=False): if layout == 'NCHW': @@ -40,7 +42,7 @@ def check_device(device): tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3, atol=1e-3) - for device in ['llvm', 'cuda', 'vulkan', 'nvptx']: + for device in get_all_backend(): check_device(device) def test_resize(): diff --git a/topi/tests/python/test_topi_upsampling.py b/topi/tests/python/test_topi_upsampling.py index 8b0ba519736a..60f6e5655fff 100644 --- a/topi/tests/python/test_topi_upsampling.py +++ b/topi/tests/python/test_topi_upsampling.py @@ -5,6 +5,8 @@ import topi.testing import math +from common import get_all_backend + def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCHW', method="NEAREST_NEIGHBOR"): @@ -45,7 +47,7 @@ def check_device(device): tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5) - for device in ['llvm', 'cuda', 'vulkan', 'nvptx']: + for device in get_all_backend(): check_device(device) def test_upsampling():