Fix LayoutRewriter (#10118)
* Fix layout pass

* add unit test

* fix lint

* fix lint

* fix lint
lazycal authored Feb 3, 2022
1 parent 8ce1b6c commit e53cbe4
Showing 3 changed files with 40 additions and 22 deletions.
12 changes: 0 additions & 12 deletions src/relay/op/nn/nn.cc
@@ -210,10 +210,6 @@ InferCorrectLayoutOutput DenseInferCorrectLayout(const Attrs& attrs,
                                                  const Array<Layout>& new_in_layouts,
                                                  const Array<Layout>& old_in_layouts,
                                                  const Array<tvm::relay::Type>& old_in_types) {
-  // Respect input layout, if explicitly specified (for example, "NW").
-  if (new_in_layouts.size() > 0 && new_in_layouts[0].defined()) {
-    return InferCorrectLayoutOutput({new_in_layouts[0], "NC"}, {"NC"}, attrs);
-  }
   return InferCorrectLayoutOutput({"NC", "NC"}, {"NC"}, attrs);
 }

@@ -283,14 +279,6 @@ InferCorrectLayoutOutput DensePackInferCorrectLayout(const Attrs& attrs,
                                                      const Array<tvm::relay::Type>& old_in_types) {
   auto params = attrs.as<DensePackAttrs>();
   ICHECK(params);
-  // Respect input layout, if explicitly specified (for example, "NW").
-  // However, a packed layout such as "NC8c" is not supported by dense_pack op. For such cases,
-  // we insert a layout transform "NC8c" -> "NC".
-  // We do not expect to get a packed layout like "NW8w", which is not compatitble with "NC",
-  // since packing is always done on the "C" axis.
-  if (new_in_layouts.size() > 0 && new_in_layouts[0].defined() && new_in_layouts[0].ndim() == 2) {
-    return InferCorrectLayoutOutput({new_in_layouts[0], params->weight_layout}, {"NC"}, attrs);
-  }
   return InferCorrectLayoutOutput({"NC", params->weight_layout}, {"NC"}, attrs);
 }

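The two branches removed above are what allowed dense and dense_pack to echo whatever layout their producer reported. A hypothetical Python sketch (names and return structure are illustrative, not the TVM C++ API) of the simplified behaviour: the ops now always request plain "NC" inputs and rely on LayoutRewriter to insert any bridging layout_transform.

def dense_infer_correct_layout(new_in_layouts, weight_layout="NC"):
    # Illustrative sketch, not TVM code. The removed behaviour echoed
    # new_in_layouts[0] whenever the producer defined it; the current
    # behaviour unconditionally requests the canonical dense layouts, so a
    # producer layout such as "NC8c" is bridged by a layout_transform instead.
    return {"input_layouts": ["NC", weight_layout], "output_layouts": ["NC"]}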
37 changes: 27 additions & 10 deletions src/relay/transforms/transform_layout.h
@@ -320,7 +320,10 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
   }
 
   // old_in, new_in = state[inputs]
-  Array<Layout> old_in, old_out, new_in, new_out, new_in2;
+  // naming rule:
+  // old_in, new_in: the input layouts given by downstream node.
+  // old_in2, new_in2: the input layouts inferred by the current node.
+  Array<Layout> old_in, old_in2, old_out, new_in, new_out, new_in2;
   for (auto inp : inputs) {
     old_in.push_back(inp->old_layout);
     new_in.push_back(inp->new_layout);
@@ -336,17 +339,18 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
   InferCorrectLayoutOutput infer_out;
   std::tie(infer_out, success) =
       InferCorrectLayouts(ref_call, Array<Layout>(nullptr), old_in, types);
-  old_in = infer_out->input_layouts;
+  old_in2 = infer_out->input_layouts;
   old_out = infer_out->output_layouts;
   if (!success) {
     return Expr(nullptr);
   }
-  ICHECK_EQ(old_in.size(), new_in.size());
+  ICHECK_EQ(old_in2.size(), new_in.size());
 
-  // if new_in == 'undef': new_in = old_in
-  for (size_t i = 0; i < new_in.size(); ++i) {
-    if (!new_in[i].defined()) {
-      new_in.Set(i, old_in[i]);
+  Array<Layout> new_in_tmp = new_in;  // for backward compatibility of InferCorrectLayouts
+  // if new_in_tmp == 'undef': new_in_tmp = old_in2
+  for (size_t i = 0; i < new_in_tmp.size(); ++i) {
+    if (!new_in_tmp[i].defined()) {
+      new_in_tmp.Set(i, old_in2[i]);
     }
   }

@@ -356,7 +360,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
   // new_in2, new_out = op.infer(new_in)
   if (new_call->op->IsInstance<OpNode>()) {
     success = false;
-    std::tie(infer_out, success) = InferCorrectLayouts(new_call, new_in, old_in, types);
+    std::tie(infer_out, success) = InferCorrectLayouts(new_call, new_in_tmp, old_in2, types);
     new_in2 = infer_out->input_layouts;
     new_out = infer_out->output_layouts;
     if (!success) {
@@ -371,6 +375,17 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
   ICHECK_EQ(new_in.size(), new_in2.size())
       << "The number of input nodes should keep the same during alter_op_layout";
 
+  auto transform_layout = [&memorizer](Expr arg_item, const Layout& old_in, const Layout& old_in2,
+                                       const Layout& new_in, const Layout& new_in2) {
+    if (old_in2.Equals(old_in)) {  // the two transforms can be fused to one
+      arg_item = memorizer.Transform(arg_item, new_in, new_in2);
+    } else {
+      if (old_in.defined()) arg_item = memorizer.Transform(arg_item, new_in, old_in);
+      arg_item = memorizer.Transform(arg_item, old_in2, new_in2);
+    }
+    return arg_item;
+  };
+
   // if (new_in != new_in2): insert transform (new_in -> new_in2)
   Array<Expr> transformed_args;
   size_t pt = 0;
@@ -380,12 +395,14 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
       Array<Expr> transformed_tuple_arg;
       transformed_tuple_arg.reserve(tuple_arg->fields.size());
       for (auto arg_item : tuple_arg->fields) {
-        transformed_tuple_arg.push_back(memorizer.Transform(arg_item, new_in[pt], new_in2[pt]));
+        transformed_tuple_arg.push_back(
+            transform_layout(arg_item, old_in[pt], old_in2[pt], new_in[pt], new_in2[pt]));
         pt++;
       }
       transformed_args.push_back(WithFields(tuple_arg, transformed_tuple_arg));
     } else {
-      transformed_args.push_back(memorizer.Transform(arg, new_in[pt], new_in2[pt]));
+      transformed_args.push_back(
+          transform_layout(arg, old_in[pt], old_in2[pt], new_in[pt], new_in2[pt]));
       pt++;
     }
   }
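A minimal Python sketch of the two-path decision made by the new transform_layout lambda above (illustrative only: layouts are plain strings and transform stands in for memorizer.Transform). When the layout recorded for the producer (old_in) matches the layout this node infers for the old graph (old_in2), the hops new_in -> old_in and old_in2 -> new_in2 fuse into a single transform; when they differ, for example because a transpose changed axis semantics in between, the rewriter first undoes the producer's rewrite and then applies the current node's expected one.

def transform_layout(arg, old_in, old_in2, new_in, new_in2, transform):
    # Sketch of the C++ lambda above; not TVM code.
    if old_in2 == old_in:
        # Producer and current node agree on the old layout, so a single
        # transform new_in -> new_in2 is enough.
        return transform(arg, new_in, new_in2)
    # Axis semantics changed between producer and current node: undo the
    # producer's rewrite first, then apply the current node's expected layout.
    if old_in is not None:
        arg = transform(arg, new_in, old_in)
    return transform(arg, old_in2, new_in2)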
13 changes: 13 additions & 0 deletions tests/python/relay/test_pass_alter_op_layout.py
@@ -1471,5 +1471,18 @@ def test_conv2d_reduce_channels():
         relay.build(mod, params=params, target="llvm")
 
 
+def test_axis_semantic_change():
+    x = relay.var("x", shape=(1, 1, 24, 48))
+    w1 = relay.const(np.random.uniform(size=(1, 1, 1, 1)))
+    w2 = relay.const(np.random.uniform(size=(1, 1, 1, 1)))
+    y = relay.nn.conv2d(x, w1, kernel_size=(1, 1), padding=(0, 0), channels=1)
+    y = relay.transpose(y, (0, 1, 3, 2))
+    z = relay.nn.conv2d(y, w2, kernel_size=(1, 1), padding=(0, 0), channels=1)
+    func = relay.Function([x], z)
+    mod = tvm.IRModule.from_expr(func)
+    with tvm.transform.PassContext(opt_level=3):
+        relay.build(mod, target="llvm")
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
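The new test exercises exactly this axis-semantics change: the transpose between the two conv2d calls swaps the height and width axes, so the layout the first convolution's output advertises no longer lines up with what the second convolution infers. Assuming a TVM build with LLVM enabled, it can be run on its own with:

pytest tests/python/relay/test_pass_alter_op_layout.py::test_axis_semantic_change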
