diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 24966019db6d..f81354466664 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -74,8 +74,7 @@ def schedule_pool_arm_cpu(attrs, outs, target): and layout in ("NWC", "NHWC") ): return topi.arm_cpu.schedule_pool(outs, layout) - logger.warning("pool is not optimized for arm cpu.") - return topi.generic.schedule_pool(outs, layout) + return topi.x86.schedule_pool(outs, layout) def _get_padding_width(padding): diff --git a/tests/python/relay/strategy/test_select_implementation.py b/tests/python/relay/strategy/test_select_implementation.py index 20dfe9670ab3..2bf1548d41d8 100644 --- a/tests/python/relay/strategy/test_select_implementation.py +++ b/tests/python/relay/strategy/test_select_implementation.py @@ -19,12 +19,14 @@ import pytest import numpy as np +from unittest.mock import MagicMock import tvm from tvm import relay from tvm import te from tvm.relay.testing import run_infer_type import tvm.testing +from tvm import topi @pytest.mark.parametrize( @@ -187,5 +189,34 @@ def test_dense(target, expected_valid_impl, expected_impl): assert selected_impl.name == expected_impl +@pytest.mark.parametrize( + "target,schedule_func", + [ + ("llvm -device=arm_cpu", topi.x86), + ("c -device=arm_cpu -mcpu=cortex-m55", topi.arm_cpu), + ], +) +def test_pool2d(target, schedule_func, monkeypatch): + target = tvm.target.Target(target) + + data_shape = (1, 2, 2, 4) + dtype = "float32" + + out = relay.nn.avg_pool2d(relay.var("data", shape=data_shape, dtype=dtype)) + placeholders = [te.placeholder(data_shape, dtype)] + + mock_schedule = MagicMock() + monkeypatch.setattr(schedule_func, "schedule_pool", mock_schedule) + + # Since pool does not use OpStrategy to determine the relevant schedule, + # we cannot simply check the schedule name that was selected with + # `select_implementation`. With this implementation of schedule selection, + # "pool.arm_cpu" will always be the schedule name, regardless of what schedule + # was selected. Instead, this test checks that the relevant schedule function + # is called when selecting the pooling from schedule from arm_cpu. + relay.op.strategy.arm_cpu.schedule_pool_arm_cpu(out.attrs, placeholders, target) + mock_schedule.assert_called() + + if __name__ == "__main__": tvm.testing.main()