diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 89ef2708ff27..e0e05436a8e6 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -210,10 +210,6 @@ InferCorrectLayoutOutput DenseInferCorrectLayout(const Attrs& attrs,
                                                  const Array<Layout>& new_in_layouts,
                                                  const Array<Layout>& old_in_layouts,
                                                  const Array<tvm::relay::Type>& old_in_types) {
-  // Respect input layout, if explicitly specified (for example, "NW").
-  if (new_in_layouts.size() > 0 && new_in_layouts[0].defined()) {
-    return InferCorrectLayoutOutput({new_in_layouts[0], "NC"}, {"NC"}, attrs);
-  }
   return InferCorrectLayoutOutput({"NC", "NC"}, {"NC"}, attrs);
 }
 
@@ -283,14 +279,6 @@ InferCorrectLayoutOutput DensePackInferCorrectLayout(const Attrs& attrs,
                                                      const Array<tvm::relay::Type>& old_in_types) {
   auto params = attrs.as<DensePackAttrs>();
   ICHECK(params);
-  // Respect input layout, if explicitly specified (for example, "NW").
-  // However, a packed layout such as "NC8c" is not supported by dense_pack op. For such cases,
-  // we insert a layout transform "NC8c" -> "NC".
-  // We do not expect to get a packed layout like "NW8w", which is not compatitble with "NC",
-  // since packing is always done on the "C" axis.
-  if (new_in_layouts.size() > 0 && new_in_layouts[0].defined() && new_in_layouts[0].ndim() == 2) {
-    return InferCorrectLayoutOutput({new_in_layouts[0], params->weight_layout}, {"NC"}, attrs);
-  }
   return InferCorrectLayoutOutput({"NC", params->weight_layout}, {"NC"}, attrs);
 }
 
diff --git a/src/relay/transforms/transform_layout.h b/src/relay/transforms/transform_layout.h
index 3dbf10e0611b..66689ae38f66 100644
--- a/src/relay/transforms/transform_layout.h
+++ b/src/relay/transforms/transform_layout.h
@@ -320,7 +320,10 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
   }
 
   // old_in, new_in = state[inputs]
-  Array<Layout> old_in, old_out, new_in, new_out, new_in2;
+  // naming rule:
+  // old_in, new_in: the input layouts given by downstream node.
+  // old_in2, new_in2: the input layouts inferred by the current node.
+  Array<Layout> old_in, old_in2, old_out, new_in, new_out, new_in2;
   for (auto inp : inputs) {
     old_in.push_back(inp->old_layout);
     new_in.push_back(inp->new_layout);
@@ -336,17 +339,18 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
   InferCorrectLayoutOutput infer_out;
   std::tie(infer_out, success) =
       InferCorrectLayouts(ref_call, Array<Layout>(nullptr), old_in, types);
-  old_in = infer_out->input_layouts;
+  old_in2 = infer_out->input_layouts;
   old_out = infer_out->output_layouts;
   if (!success) {
     return Expr(nullptr);
   }
-  ICHECK_EQ(old_in.size(), new_in.size());
+  ICHECK_EQ(old_in2.size(), new_in.size());
 
-  // if new_in == 'undef': new_in = old_in
-  for (size_t i = 0; i < new_in.size(); ++i) {
-    if (!new_in[i].defined()) {
-      new_in.Set(i, old_in[i]);
+  Array<Layout> new_in_tmp = new_in;  // for backward compatibility of InferCorrectLayouts
+  // if new_in_tmp == 'undef': new_in_tmp = old_in2
+  for (size_t i = 0; i < new_in_tmp.size(); ++i) {
+    if (!new_in_tmp[i].defined()) {
+      new_in_tmp.Set(i, old_in2[i]);
     }
   }
 
@@ -356,7 +360,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
   // new_in2, new_out = op.infer(new_in)
   if (new_call->op->IsInstance<OpNode>()) {
     success = false;
-    std::tie(infer_out, success) = InferCorrectLayouts(new_call, new_in, old_in, types);
+    std::tie(infer_out, success) = InferCorrectLayouts(new_call, new_in_tmp, old_in2, types);
     new_in2 = infer_out->input_layouts;
     new_out = infer_out->output_layouts;
     if (!success) {
@@ -371,6 +375,17 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
   ICHECK_EQ(new_in.size(), new_in2.size())
       << "The number of input nodes should keep the same during alter_op_layout";
 
+  auto transform_layout = [&memorizer](Expr arg_item, const Layout& old_in, const Layout& old_in2,
+                                       const Layout& new_in, const Layout& new_in2) {
+    if (old_in2.Equals(old_in)) {  // the two transforms can be fused to one
+      arg_item = memorizer.Transform(arg_item, new_in, new_in2);
+    } else {
+      if (old_in.defined()) arg_item = memorizer.Transform(arg_item, new_in, old_in);
+      arg_item = memorizer.Transform(arg_item, old_in2, new_in2);
+    }
+    return arg_item;
+  };
+
   // if (new_in != new_in2): insert transform (new_in -> new_in2)
   Array<Expr> transformed_args;
   size_t pt = 0;
@@ -380,12 +395,14 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
       Array<Expr> transformed_tuple_arg;
       transformed_tuple_arg.reserve(tuple_arg->fields.size());
       for (auto arg_item : tuple_arg->fields) {
-        transformed_tuple_arg.push_back(memorizer.Transform(arg_item, new_in[pt], new_in2[pt]));
+        transformed_tuple_arg.push_back(
+            transform_layout(arg_item, old_in[pt], old_in2[pt], new_in[pt], new_in2[pt]));
         pt++;
       }
       transformed_args.push_back(WithFields(tuple_arg, transformed_tuple_arg));
     } else {
-      transformed_args.push_back(memorizer.Transform(arg, new_in[pt], new_in2[pt]));
+      transformed_args.push_back(
+          transform_layout(arg, old_in[pt], old_in2[pt], new_in[pt], new_in2[pt]));
       pt++;
     }
   }
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index ea7fe0bd7871..bb5d3e3ab6c2 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -1471,5 +1471,18 @@ def test_conv2d_reduce_channels():
     relay.build(mod, params=params, target="llvm")
 
 
+def test_axis_semantic_change():
+    x = relay.var("x", shape=(1, 1, 24, 48))
+    w1 = relay.const(np.random.uniform(size=(1, 1, 1, 1)))
+    w2 = relay.const(np.random.uniform(size=(1, 1, 1, 1)))
+    y = relay.nn.conv2d(x, w1, kernel_size=(1, 1), padding=(0, 0), channels=1)
+    y = relay.transpose(y, (0, 1, 3, 2))
+    z = relay.nn.conv2d(y, w2, kernel_size=(1, 1), padding=(0, 0), channels=1)
+    func = relay.Function([x], z)
+    mod = tvm.IRModule.from_expr(func)
+    with tvm.transform.PassContext(opt_level=3):
+        relay.build(mod, target="llvm")
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
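
Note for reviewers: the heart of this patch is the `transform_layout` lambda in `LayoutRewriter`. The old code always inserted a single `memorizer.Transform(arg, new_in, new_in2)`, which silently assumed that the layout recorded by the producer (`old_in`) and the layout this node infers for the same value (`old_in2`) agree; a `transpose` between two convs breaks that assumption because it changes which axis means what. The sketch below is a minimal pure-Python model of the new decision, not TVM API: `transform` is a hypothetical callback standing in for `memorizer.Transform`, `None` models an undefined `Layout`, and the layout strings are only illustrative.

```python
# Minimal model of the transform_layout lambda added in transform_layout.h.
# transform(arg, src, dst) stands for memorizer.Transform, i.e.
# "wrap arg in a layout_transform from layout src to layout dst".
def transform_layout(arg, old_in, old_in2, new_in, new_in2, transform):
    if old_in2 == old_in:
        # Producer-recorded and locally inferred layouts agree, so the round
        # trip new_in -> old -> new_in2 fuses into one transform, exactly the
        # single-transform path the old code took unconditionally.
        return transform(arg, new_in, new_in2)
    # Axis semantics changed between the producer and this node (e.g. a
    # transpose sits in between): first undo the producer's rewrite
    # (new_in -> old_in), then apply this node's expected rewrite
    # (old_in2 -> new_in2) on the original axis order.
    if old_in is not None:  # models Layout::defined()
        arg = transform(arg, new_in, old_in)
    return transform(arg, old_in2, new_in2)


if __name__ == "__main__":
    # Record the transforms that would be inserted.
    trace = lambda arg, src, dst: f"layout_transform({arg}, {src}->{dst})"
    # Layouts agree: one fused transform.
    print(transform_layout("x", "NCHW", "NCHW", "NCHW4c", "NCHW8c", trace))
    # -> layout_transform(x, NCHW4c->NCHW8c)
    # Layouts disagree: undo the producer's rewrite, then redo.
    print(transform_layout("x", "NCHW", "NHWC", "NCHW4c", "NHWC8c", trace))
    # -> layout_transform(layout_transform(x, NCHW4c->NCHW), NHWC->NHWC8c)
```

When the two layouts agree, the fused branch emits exactly what the old code emitted, so graphs without axis-semantic changes see no extra `layout_transform` nodes; only the previously miscompiled conv-transpose-conv pattern (covered by `test_axis_semantic_change`) takes the undo-then-redo path.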