From a25a3d3dee3651a4451db9cd58c4a021e17ca941 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 7 Oct 2022 17:05:27 +0900 Subject: [PATCH] Fix after MS API change --- python/tvm/meta_schedule/relay_integration.py | 14 +++- src/meta_schedule/postproc/postproc.cc | 5 +- .../test_meta_schedule_relay_integration.py | 78 +++++++++++++++++++ .../test_meta_schedule_vnni_integration.py | 4 +- 4 files changed, 92 insertions(+), 9 deletions(-) diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py index af992dd4bc8bb..b3d8d582ba2b7 100644 --- a/python/tvm/meta_schedule/relay_integration.py +++ b/python/tvm/meta_schedule/relay_integration.py @@ -93,8 +93,15 @@ def _normalize_params( if isinstance(param, np.ndarray): param = nd.array(param) relay_params[name] = param - if executor is not None: + + if executor is None: + executor = relay.backend.Executor("graph") + + if mod.get_attr("executor") is None: mod = mod.with_attr("executor", executor) + else: + executor = mod.get_attr("executor") + pass_config = dict(pass_config) return mod, target, relay_params, pass_config, executor @@ -384,8 +391,7 @@ def is_meta_schedule_dispatch_enabled() -> bool: enabled: bool Whether the meta schedule is enabled """ - result = transform.PassContext.current().config.get( + return transform.PassContext.current().config.get( "relay.backend.use_meta_schedule_dispatch", - 0, + False, ) - return bool(result & 1) diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc index acc157e36e94d..acd783b1860d0 100644 --- a/src/meta_schedule/postproc/postproc.cc +++ b/src/meta_schedule/postproc/postproc.cc @@ -85,10 +85,9 @@ Array Postproc::DefaultCUDATensorCore() { Array Postproc::DefaultHexagon() { return Array{ Postproc::DisallowDynamicLoop(), - Postproc::RewriteParallelVectorizeUnroll(), // + Postproc::RewriteParallelVectorizeUnroll(), Postproc::RewriteReductionBlock(), - // TODO(masahi): Fix RewriteLayout for link-params=True case - // Postproc::RewriteLayout(), + Postproc::RewriteLayout(), }; } diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py index cf61df0c6ba8f..4047f44ac365a 100644 --- a/tests/python/unittest/test_meta_schedule_relay_integration.py +++ b/tests/python/unittest/test_meta_schedule_relay_integration.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Integration test for MetaSchedule""" +import tempfile import numpy as np import pytest import tvm @@ -489,5 +490,82 @@ def get_output(data, lib): assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4) +def test_rewrite_layout_link_params(): + I, O, H, W = 64, 64, 56, 56 + kH = kW = 3 + + strides = (1, 1) + padding = (1, 1) + + data_shape = (1, H, W, I) + w_shape = (kH, kW, I, O) + bias_shape = (1, 1, 1, O) + + data = relay.var("data", shape=data_shape, dtype="float32") + weight = relay.var("weight1", shape=w_shape, dtype="float32") + bias = relay.var("bias", shape=bias_shape, dtype="float32") + + conv = relay.nn.conv2d( + data=data, + weight=weight, + kernel_size=(kH, kW), + channels=O, + padding=padding, + strides=strides, + data_layout="NHWC", + kernel_layout="HWIO", + out_dtype="float32", + ) + + mod = tvm.IRModule.from_expr(conv + bias) + + weight_np = np.random.randn(*w_shape).astype("float32") + bias_np = np.random.randn(*bias_shape).astype("float32") + + params = {"weight1": weight_np, "bias": bias_np} + + data_np = np.random.randn(*data_shape).astype("float32") + + ref = ( + relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm") + .evaluate()(*[data_np, weight_np, bias_np]) + .numpy() + ) + + link_params = True + + target = "llvm --num-cores=4" + + executor = relay.backend.Executor("graph", {"link-params": link_params}) + mod = mod.with_attr("executor", executor) + + with tempfile.TemporaryDirectory() as work_dir: + database = ms.relay_integration.tune_relay( + mod=mod, + target=target, + params=params, + work_dir=work_dir, + max_trials_global=4, + strategy="replay-trace", + ) + + lib = ms.relay_integration.compile_relay( + database=database, + mod=mod, + target=target, + params=params, + ) + + dev = tvm.device(target, 0) + runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + + runtime.set_input("data", data_np) + runtime.run() + + out = runtime.get_output(0).numpy() + + np.testing.assert_allclose(ref, out, rtol=1e-4, atol=1e-4) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_meta_schedule_vnni_integration.py b/tests/python/unittest/test_meta_schedule_vnni_integration.py index 2cd6098630560..710ea96d9f5c1 100644 --- a/tests/python/unittest/test_meta_schedule_vnni_integration.py +++ b/tests/python/unittest/test_meta_schedule_vnni_integration.py @@ -133,7 +133,7 @@ def f_check(lib, dev): return relay_mod, params, f_check -@pytest.mark.skip("Requires cascadelake") +@tvm.testing.requires_cascadelake def test_vnni_schedule_fn_database(): m, n, k = 1024, 1024, 1024 target = tvm.target.Target("llvm -mcpu=cascadelake -num-cores 4") @@ -164,7 +164,7 @@ def test_vnni_schedule_fn_database(): f_check(lib, dev) -@pytest.mark.skip("Requires cascadelake") +@tvm.testing.requires_cascadelake def test_vnni_schedule_fn_tune(): # pylint: disable=W0105 """