Skip to content

Commit

Permalink
[Relay][Strategy] Use x86 pool schedules for arm_cpu
Browse files Browse the repository at this point in the history
Similar to apache#15470, x86 schedules are
used in place of generic schedules to improve performance.

Since the pooling strategy does not use `OpStrategy`, mocking is used
to ensure the relevant `schedule_pool` function is called when lowing a
Relay pooling operation with respect to a given target.

Change-Id: I782fe00e29f9c9cf41b3405d33a82a79cd85a99b
  • Loading branch information
lhutton1 committed Aug 8, 2023
1 parent 8cadd1f commit fe2e63f
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
3 changes: 1 addition & 2 deletions python/tvm/relay/op/strategy/arm_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ def schedule_pool_arm_cpu(attrs, outs, target):
and layout in ("NWC", "NHWC")
):
return topi.arm_cpu.schedule_pool(outs, layout)
logger.warning("pool is not optimized for arm cpu.")
return topi.generic.schedule_pool(outs, layout)
return topi.x86.schedule_pool(outs, layout)


def _get_padding_width(padding):
Expand Down
35 changes: 35 additions & 0 deletions tests/python/relay/strategy/test_select_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@

import pytest
import numpy as np
from unittest.mock import MagicMock

import tvm
from tvm import relay
from tvm import te
from tvm.relay.testing import run_infer_type
import tvm.testing
from tvm import topi


@pytest.mark.parametrize(
Expand Down Expand Up @@ -187,5 +189,38 @@ def test_dense(target, expected_valid_impl, expected_impl):
assert selected_impl.name == expected_impl


@pytest.mark.parametrize(
"target,schedule_func",
[
("llvm -device=arm_cpu", topi.x86),
("c -device=arm_cpu -mcpu=cortex-m55", topi.arm_cpu),
],
)
def test_pool2d(target, schedule_func, monkeypatch):
target = tvm.target.Target(target)

data_shape = (1, 2, 2, 4)
dtype = "float32"

out = relay.nn.avg_pool2d(relay.var("data", shape=data_shape, dtype=dtype))
out = relay.Function(relay.analysis.free_vars(out), out)
out = tvm.IRModule.from_expr(out)

# Since pool does not use OpStrategy to determine the relevant schedule,
# we cannot simply check the schedule name that was selected with
# `select_implementation`. With this implementation of schedule selection,
# "pool.arm_cpu" will always be the schedule name, regardless of what schedule
# was selected. Instead, this test checks that the relevant schedule function
# is called when building the module.
mock_schedule = MagicMock()
mock_schedule.side_effect = lambda outs, layout: topi.generic.schedule_pool(outs, layout)
monkeypatch.setattr(schedule_func, "schedule_pool", mock_schedule)

with target:
tvm.relay.build(out, target=target)

mock_schedule.assert_called()


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit fe2e63f

Please sign in to comment.