Skip to content

Commit

Permalink
Expose FTVMInferCorrectLayout Python interface
Browse files Browse the repository at this point in the history
  • Loading branch information
kueitang committed Aug 16, 2021
1 parent e12ddca commit f426b6b
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 0 deletions.
17 changes: 17 additions & 0 deletions python/tvm/relay/op/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,23 @@ def register_convert_op_layout(op_name, convert_layout=None, level=10):
return tvm.ir.register_op_attr(op_name, "FTVMConvertOpLayout", convert_layout, level)


def register_infer_correct_layout(op_name, infer_layout=None, level=10):
"""Register infer op layout function for an op
Parameters
----------
op_name : str
The name of the operator
infer_layout: function (attrs: Attrs, inputs: List[Layout]) -> InferCorrectLayoutOutput
The function to infer correct layout
level : int
The priority level
"""
return tvm.ir.register_op_attr(op_name, "FInferCorrectLayout", infer_layout, level)


def register_legalize(op_name, legal_op=None, level=10):
"""Register legal transformation function for an op
Expand Down
33 changes: 33 additions & 0 deletions python/tvm/relay/transform/infer_layout_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument, missing-docstring, unused-import
"""
Relay infer correct layout pass.
"""
import tvm
from tvm.runtime import Object
from . import _ffi_api


@tvm._ffi.register_object("relay._transform.InferCorrectLayoutOutput")
class InferCorrectLayoutOutput(Object):
"""An output structure to hold results from FInferCorrectLayout calls."""

def __init__(self, input_layouts, output_layouts, new_attrs):
self.__init_handle_by_constructor__(
_ffi_api.InferCorrectLayoutOutput, input_layouts, output_layouts, new_attrs
)
7 changes: 7 additions & 0 deletions src/relay/transforms/convert_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,13 @@ Pass ConvertLayout(const Map<String, Array<String>>& desired_layouts) {

TVM_REGISTER_GLOBAL("relay._transform.ConvertLayout").set_body_typed(ConvertLayout);

TVM_REGISTER_GLOBAL("relay._transform.InferCorrectLayoutOutput")
.set_body_typed([](Array<Layout> input_layouts, Array<Layout> output_layouts, Attrs new_attrs) {
return InferCorrectLayoutOutput(input_layouts, output_layouts, new_attrs);
});

TVM_REGISTER_NODE_TYPE(InferCorrectLayoutOutputNode);

} // namespace transform

} // namespace relay
Expand Down
9 changes: 9 additions & 0 deletions src/relay/transforms/infer_layout_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,16 @@ class InferCorrectLayoutOutputNode : public Object {
Array<Layout> input_layouts;
Array<Layout> output_layouts;
Attrs new_attrs;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("input_layouts", &input_layouts);
v->Visit("output_layouts", &output_layouts);
v->Visit("new_attrs", &new_attrs);
}

TVM_DECLARE_BASE_OBJECT_INFO(InferCorrectLayoutOutputNode, Object);

static constexpr const char* _type_key = "relay._transform.InferCorrectLayoutOutput";
};

class InferCorrectLayoutOutput : public ObjectRef {
Expand Down
43 changes: 43 additions & 0 deletions tests/python/relay/test_pass_convert_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from tvm import relay
from tvm.relay.op import register_alter_op_layout
from tvm.relay import transform, analysis
from tvm.relay.transform.infer_layout_utils import InferCorrectLayoutOutput
from tvm.relay.op import op as reg


def run_opt_pass(expr, passes):
Expand Down Expand Up @@ -1881,6 +1883,46 @@ def expected():
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)


def test_infer_correct_layout():
test_infer_correct_layout_flag = False

def before():
x = relay.var("x", shape=(1, 56, 56, 64))
weight = relay.var("weight", shape=(3, 3, 64, 64))
y = relay.nn.conv2d(
x,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
)
y = relay.nn.relu(y)
y = relay.Function([x, weight], y)
return y

@reg.register_infer_correct_layout("nn.relu", level=11)
def infer_correct_layout_relu(attrs, new_in_layouts, old_in_layouts, old_in_types):
nonlocal test_infer_correct_layout_flag
test_infer_correct_layout_flag = True

ret = []
if new_in_layouts:
assert len(new_in_layouts) >= 1
ret = new_in_layouts[0]
else:
for i in range(len(old_in_layouts)):
if old_in_layouts[i]:
ret = old_in_layouts[i]
break
return InferCorrectLayoutOutput([ret], [ret], attrs)

a = before()
a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
assert test_infer_correct_layout_flag == True


if __name__ == "__main__":
test_qnn_binary_no_convert_layout()
test_no_convert_layout()
Expand Down Expand Up @@ -1914,3 +1956,4 @@ def expected():
test_conv_strided_slice_axes_convert_layout()
test_image_resize_convert_layout()
test_conv_image_resize_convert_layout()
test_infer_correct_layout()

0 comments on commit f426b6b

Please sign in to comment.