Skip to content

Commit

Permalink
Fix after MS API change
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Oct 7, 2022
1 parent c38fc91 commit a25a3d3
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 9 deletions.
14 changes: 10 additions & 4 deletions python/tvm/meta_schedule/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,15 @@ def _normalize_params(
if isinstance(param, np.ndarray):
param = nd.array(param)
relay_params[name] = param
if executor is not None:

if executor is None:
executor = relay.backend.Executor("graph")

if mod.get_attr("executor") is None:
mod = mod.with_attr("executor", executor)
else:
executor = mod.get_attr("executor")

pass_config = dict(pass_config)
return mod, target, relay_params, pass_config, executor

Expand Down Expand Up @@ -384,8 +391,7 @@ def is_meta_schedule_dispatch_enabled() -> bool:
enabled: bool
Whether the meta schedule is enabled
"""
result = transform.PassContext.current().config.get(
return transform.PassContext.current().config.get(
"relay.backend.use_meta_schedule_dispatch",
0,
False,
)
return bool(result & 1)
5 changes: 2 additions & 3 deletions src/meta_schedule/postproc/postproc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,9 @@ Array<Postproc> Postproc::DefaultCUDATensorCore() {
Array<Postproc> Postproc::DefaultHexagon() {
return Array<Postproc>{
Postproc::DisallowDynamicLoop(),
Postproc::RewriteParallelVectorizeUnroll(), //
Postproc::RewriteParallelVectorizeUnroll(),
Postproc::RewriteReductionBlock(),
// TODO(masahi): Fix RewriteLayout for link-params=True case
// Postproc::RewriteLayout(),
Postproc::RewriteLayout(),
};
}

Expand Down
78 changes: 78 additions & 0 deletions tests/python/unittest/test_meta_schedule_relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Integration test for MetaSchedule"""
import tempfile
import numpy as np
import pytest
import tvm
Expand Down Expand Up @@ -489,5 +490,82 @@ def get_output(data, lib):
assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4)


def test_rewrite_layout_link_params():
I, O, H, W = 64, 64, 56, 56
kH = kW = 3

strides = (1, 1)
padding = (1, 1)

data_shape = (1, H, W, I)
w_shape = (kH, kW, I, O)
bias_shape = (1, 1, 1, O)

data = relay.var("data", shape=data_shape, dtype="float32")
weight = relay.var("weight1", shape=w_shape, dtype="float32")
bias = relay.var("bias", shape=bias_shape, dtype="float32")

conv = relay.nn.conv2d(
data=data,
weight=weight,
kernel_size=(kH, kW),
channels=O,
padding=padding,
strides=strides,
data_layout="NHWC",
kernel_layout="HWIO",
out_dtype="float32",
)

mod = tvm.IRModule.from_expr(conv + bias)

weight_np = np.random.randn(*w_shape).astype("float32")
bias_np = np.random.randn(*bias_shape).astype("float32")

params = {"weight1": weight_np, "bias": bias_np}

data_np = np.random.randn(*data_shape).astype("float32")

ref = (
relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm")
.evaluate()(*[data_np, weight_np, bias_np])
.numpy()
)

link_params = True

target = "llvm --num-cores=4"

executor = relay.backend.Executor("graph", {"link-params": link_params})
mod = mod.with_attr("executor", executor)

with tempfile.TemporaryDirectory() as work_dir:
database = ms.relay_integration.tune_relay(
mod=mod,
target=target,
params=params,
work_dir=work_dir,
max_trials_global=4,
strategy="replay-trace",
)

lib = ms.relay_integration.compile_relay(
database=database,
mod=mod,
target=target,
params=params,
)

dev = tvm.device(target, 0)
runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))

runtime.set_input("data", data_np)
runtime.run()

out = runtime.get_output(0).numpy()

np.testing.assert_allclose(ref, out, rtol=1e-4, atol=1e-4)


if __name__ == "__main__":
tvm.testing.main()
4 changes: 2 additions & 2 deletions tests/python/unittest/test_meta_schedule_vnni_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def f_check(lib, dev):
return relay_mod, params, f_check


@pytest.mark.skip("Requires cascadelake")
@tvm.testing.requires_cascadelake
def test_vnni_schedule_fn_database():
m, n, k = 1024, 1024, 1024
target = tvm.target.Target("llvm -mcpu=cascadelake -num-cores 4")
Expand Down Expand Up @@ -164,7 +164,7 @@ def test_vnni_schedule_fn_database():
f_check(lib, dev)


@pytest.mark.skip("Requires cascadelake")
@tvm.testing.requires_cascadelake
def test_vnni_schedule_fn_tune():
# pylint: disable=W0105
"""
Expand Down

0 comments on commit a25a3d3

Please sign in to comment.