diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 0d90a5cdeafa3..bbaffc469321f 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -341,6 +341,23 @@ def register_convert_op_layout(op_name, convert_layout=None, level=10): return tvm.ir.register_op_attr(op_name, "FTVMConvertOpLayout", convert_layout, level) +def register_infer_correct_layout(op_name, infer_layout=None, level=10): + """Register infer op layout function for an op + + Parameters + ---------- + op_name : str + The name of the operator + + infer_layout: function (attrs: Attrs, inputs: List[Layout]) -> InferCorrectLayoutOutput + The function to infer correct layout + + level : int + The priority level + """ + return tvm.ir.register_op_attr(op_name, "FInferCorrectLayout", infer_layout, level) + + def register_legalize(op_name, legal_op=None, level=10): """Register legal transformation function for an op diff --git a/python/tvm/relay/transform/infer_layout_utils.py b/python/tvm/relay/transform/infer_layout_utils.py new file mode 100755 index 0000000000000..2dc0d25e2dcd1 --- /dev/null +++ b/python/tvm/relay/transform/infer_layout_utils.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument, missing-docstring, unused-import +""" +Relay infer correct layout pass. +""" +import tvm +from tvm.runtime import Object +from . import _ffi_api + + +@tvm._ffi.register_object("relay._transform.InferCorrectLayoutOutput") +class InferCorrectLayoutOutput(Object): + """An output structure to hold results from FInferCorrectLayout calls.""" + + def __init__(self, input_layouts, output_layouts, new_attrs): + self.__init_handle_by_constructor__( + _ffi_api.InferCorrectLayoutOutput, input_layouts, output_layouts, new_attrs + ) diff --git a/src/relay/transforms/convert_layout.cc b/src/relay/transforms/convert_layout.cc index a29fdeb378322..e74ea01158575 100644 --- a/src/relay/transforms/convert_layout.cc +++ b/src/relay/transforms/convert_layout.cc @@ -155,6 +155,13 @@ Pass ConvertLayout(const Map>& desired_layouts) { TVM_REGISTER_GLOBAL("relay._transform.ConvertLayout").set_body_typed(ConvertLayout); +TVM_REGISTER_GLOBAL("relay._transform.InferCorrectLayoutOutput") + .set_body_typed([](Array input_layouts, Array output_layouts, Attrs new_attrs) { + return InferCorrectLayoutOutput(input_layouts, output_layouts, new_attrs); + }); + +TVM_REGISTER_NODE_TYPE(InferCorrectLayoutOutputNode); + } // namespace transform } // namespace relay diff --git a/src/relay/transforms/infer_layout_utils.h b/src/relay/transforms/infer_layout_utils.h index 5aedb9ff75d45..76d6aa646f4c8 100644 --- a/src/relay/transforms/infer_layout_utils.h +++ b/src/relay/transforms/infer_layout_utils.h @@ -97,7 +97,16 @@ class InferCorrectLayoutOutputNode : public Object { Array input_layouts; Array output_layouts; Attrs new_attrs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("input_layouts", &input_layouts); + v->Visit("output_layouts", &output_layouts); + v->Visit("new_attrs", &new_attrs); + } + TVM_DECLARE_BASE_OBJECT_INFO(InferCorrectLayoutOutputNode, Object); + + static constexpr const char* _type_key = "relay._transform.InferCorrectLayoutOutput"; }; class InferCorrectLayoutOutput : public ObjectRef { diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index fafab3ee35843..2bbef47d3100f 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -21,6 +21,8 @@ from tvm import relay from tvm.relay.op import register_alter_op_layout from tvm.relay import transform, analysis +from tvm.relay.transform.infer_layout_utils import InferCorrectLayoutOutput +from tvm.relay.op import op as reg def run_opt_pass(expr, passes): @@ -1881,6 +1883,46 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) +def test_infer_correct_layout(): + test_infer_correct_layout_flag = False + + def before(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight = relay.var("weight", shape=(3, 3, 64, 64)) + y = relay.nn.conv2d( + x, + weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + y = relay.nn.relu(y) + y = relay.Function([x, weight], y) + return y + + @reg.register_infer_correct_layout("nn.relu", level=11) + def infer_correct_layout_relu(attrs, new_in_layouts, old_in_layouts, old_in_types): + nonlocal test_infer_correct_layout_flag + test_infer_correct_layout_flag = True + + ret = [] + if new_in_layouts: + assert len(new_in_layouts) >= 1 + ret = new_in_layouts[0] + else: + for i in range(len(old_in_layouts)): + if old_in_layouts[i]: + ret = old_in_layouts[i] + break + return InferCorrectLayoutOutput([ret], [ret], attrs) + + a = before() + a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) + assert test_infer_correct_layout_flag == True + + if __name__ == "__main__": test_qnn_binary_no_convert_layout() test_no_convert_layout() @@ -1914,3 +1956,4 @@ def expected(): test_conv_strided_slice_axes_convert_layout() test_image_resize_convert_layout() test_conv_image_resize_convert_layout() + test_infer_correct_layout()