[MetaSchedule] Fix for RewriteLayout + AllocateConst when the rank of the rewritten weight doesn't change (apache#13851)

masahi authored and csullivan committed Feb 7, 2023
1 parent 92112c3 commit bce59fb
Showing 2 changed files with 94 additions and 1 deletion.
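The failure mode is easiest to see with a small example. The sketch below is illustrative only — the index map is a hypothetical rank-preserving rewrite, not one produced by MetaSchedule — but it shows why a rank comparison alone cannot tell an original weight apart from an already-rewritten one when the rewrite keeps the rank unchanged.

# Illustration only: a hypothetical rank-preserving layout rewrite of an
# HWIO 1x1 conv2d weight, expressed with tvm.tir.IndexMap.
import tvm
from tvm import tir

w_shape = [tvm.tir.IntImm("int64", d) for d in (1, 1, 32, 16)]  # HWIO, as in the new test

# Reorder the axes; the rank stays 4.
index_map = tir.IndexMap.from_func(lambda h, w, i, o: (o, i, h, w))
rewritten_shape = index_map.map_shape(w_shape)

# Both the original and the rewritten shape have rank 4, equal to the number of
# inputs of the index map, so the old check in te_compiler_cache.cc
#   constant.Shape().size() == index_map->initial_indices.size()
# holds before and after the rewrite and cannot distinguish the two cases.
print(len(w_shape), len(rewritten_shape), len(index_map.initial_indices))  # 4 4 4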
21 changes: 20 additions & 1 deletion src/relay/backend/te_compiler_cache.cc
@@ -585,7 +585,26 @@ class ScheduleBuilder : public ExprVisitor {
         << "Only one layout-free constant is supported by RewriteLayout for now";
     auto constant = const_collector.constants[0];
 
-    if (constant.Shape().size() == index_map->initial_indices.size()) {
+    auto is_constant_transformed = [index_map](runtime::NDArray c) {
+      if (c.Shape().size() != index_map->initial_indices.size()) {
+        return true;
+      }
+      size_t src_size_1d = 1;
+      Array<PrimExpr> orig_shape;
+      for (size_t i = 0; i < c.Shape().size(); ++i) {
+        src_size_1d *= c->shape[i];
+        orig_shape.push_back(PrimExpr(static_cast<int>((c->shape[i]))));
+      }
+      auto dst_shape = index_map->MapShape(orig_shape);
+      std::vector<int64_t> dst_shape_int;
+      size_t dst_size_1d = 1;
+      for (size_t i = 0; i < dst_shape.size(); ++i) {
+        dst_size_1d *= dst_shape[i].as<IntImmNode>()->value;
+      }
+      return src_size_1d != dst_size_1d;
+    };
+
+    if (!is_constant_transformed(constant)) {
       // This is the first case, reached during the MetaScheduleLayoutRewrite pass.
       //
       // A layout-free constant having the same rank as an input to the index map
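To see how the new element-count check distinguishes the two cases, here is a rough Python sketch. It is not part of the commit: the helper name, the index map, and the shapes are made up for illustration, and it assumes tvm.tir.IndexMap.map_shape mirrors the C++ MapShape call used above, including the implicit padding that appears when a split factor does not divide an extent.

# Illustration only: mimic the is_constant_transformed() heuristic from the C++ change above.
import tvm
from tvm import tir


def looks_already_transformed(shape, index_map):
    # A constant whose rank differs from the index map's input rank must already
    # be in the rewritten layout.
    if len(shape) != len(index_map.initial_indices):
        return True
    # Otherwise, feed the shape through the map: a genuine input preserves the
    # element count, while an already-rewritten shape generally does not, because
    # non-divisible splits introduce padding and inflate the mapped size.
    src_size = 1
    for dim in shape:
        src_size *= int(dim)
    dst_size = 1
    for dim in index_map.map_shape([tvm.tir.IntImm("int64", int(d)) for d in shape]):
        dst_size *= int(dim)
    return src_size != dst_size


# Hypothetical rank-preserving rewrite: split the input channel i by 8 and fuse
# the inner part with the output channel o; the rank stays 4.
index_map = tir.IndexMap.from_func(lambda h, w, i, o: (h, w, (i % 8) * 16 + o, i // 8))

orig_shape = (1, 1, 24, 16)  # HWIO weight of a 1x1 conv2d with 16 output channels
rewritten_shape = [int(d) for d in index_map.map_shape([tvm.tir.IntImm("int64", d) for d in orig_shape])]

print(looks_already_transformed(orig_shape, index_map))       # expected: False
print(looks_already_transformed(rewritten_shape, index_map))  # expected: True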
74 changes: 74 additions & 0 deletions tests/python/unittest/test_meta_schedule_relay_integration.py
@@ -871,5 +871,79 @@ def test_disabled_pass_param():
        pytest.fail("'disabled_pass' argument does not work")


def test_rewrite_layout_link_params_1x1_conv2d():
    I, O, H, W = 32, 16, 256, 256
    kH = kW = 1

    strides = (1, 1)
    padding = (0, 0)

    data_shape = (1, H, W, I)
    w_shape = (kH, kW, I, O)

    data = relay.var("data", shape=data_shape, dtype="float32")
    weight = relay.var("weight", shape=w_shape, dtype="float32")

    conv = relay.nn.conv2d(
        data=data,
        weight=weight,
        kernel_size=(kH, kW),
        channels=O,
        padding=padding,
        strides=strides,
        data_layout="NHWC",
        kernel_layout="HWIO",
        out_dtype="float32",
    )

    mod = tvm.IRModule.from_expr(conv)

    weight_np = np.random.randn(*w_shape).astype("float32")

    params = {"weight": weight_np}

    data_np = np.random.randn(*data_shape).astype("float32")

    ref = (
        relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm")
        .evaluate()(*[data_np, weight_np])
        .numpy()
    )

    link_params = True

    target = "llvm --num-cores=4"

    executor = relay.backend.Executor("graph", {"link-params": link_params})
    mod = mod.with_attr("executor", executor)

    with tempfile.TemporaryDirectory() as work_dir:
        database = ms.relay_integration.tune_relay(
            mod=mod,
            target=target,
            params=params,
            work_dir=work_dir,
            max_trials_global=8,
            strategy="replay-trace",
        )

        lib = ms.relay_integration.compile_relay(
            database=database,
            mod=mod,
            target=target,
            params=params,
        )

    dev = tvm.device(target, 0)
    runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))

    runtime.set_input("data", data_np)
    runtime.run()

    out = runtime.get_output(0).numpy()

    np.testing.assert_allclose(ref, out, rtol=1e-4, atol=1e-4)


if __name__ == "__main__":
    tvm.testing.main()
