From 009e536c193760e51c8f9d04def400e3f3eda1b4 Mon Sep 17 00:00:00 2001 From: trahman-quic Date: Thu, 2 Jun 2022 14:13:50 -0500 Subject: [PATCH] Address comments --- .../slice_ops/add_subtract_multiply.py | 37 +++++++++++++--- python/tvm/topi/hexagon/utils.py | 42 ++++--------------- .../test_add_subtract_multiply.py | 4 +- 3 files changed, 43 insertions(+), 40 deletions(-) diff --git a/python/tvm/topi/hexagon/slice_ops/add_subtract_multiply.py b/python/tvm/topi/hexagon/slice_ops/add_subtract_multiply.py index cb41cd078b28..f6e6d9671e5e 100755 --- a/python/tvm/topi/hexagon/slice_ops/add_subtract_multiply.py +++ b/python/tvm/topi/hexagon/slice_ops/add_subtract_multiply.py @@ -1,4 +1,26 @@ -# pylint: disable=invalid-name +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-variable, unused-argument, too-many-locals + +"""Compute and schedule for add, multiply, subtract slice op + +Please note the following assumptions made by the implementation: + +1) The inputs will be multiple of crouton layout except for the axis that needs broadcasting.""" from tvm import te from tvm import tir @@ -7,18 +29,22 @@ def add_broadcast_compute(A, B): + """Call the add op from topi""" return topi.add(A, B) def subtract_broadcast_compute(A, B): + """Call the subtract op from topi""" return topi.subtract(A, B) def multiply_broadcast_compute(A, B): + """Call the multiply op from topi""" return topi.multiply(A, B) -def get_layout(layout): +def get_2d_layout(layout): + """Get the 2d layout for transformation""" layout += "-2d" return get_layout_transform_fn(layout) @@ -26,6 +52,7 @@ def get_layout(layout): def STIR_broadcast_schedule( M, A, B, output_layout: str, input_A_layout: str, input_B_layout: str, op_name: str ): + """Schedule for input and output layout nhwc-8h2w32c2w considering broadcast""" func = te.create_prim_func([A, B, M]) s = tir.Schedule(func) @@ -35,14 +62,14 @@ def STIR_broadcast_schedule( block = s.get_block(block_dict[op_name]) if input_A_layout == "nhwc-8h2w32c2w": - input_A_transformed_layout = get_layout(input_A_layout) + input_A_transformed_layout = get_2d_layout(input_A_layout) s.transform_layout(block, buffer=("read", 0), index_map=input_A_transformed_layout) if input_B_layout == "nhwc-8h2w32c2w": - input_B_transformed_layout = get_layout(input_B_layout) + input_B_transformed_layout = get_2d_layout(input_B_layout) s.transform_layout(block, buffer=("read", 1), index_map=input_B_transformed_layout) - output_transformed_layout = get_layout(output_layout) + output_transformed_layout = get_2d_layout(output_layout) s.transform_layout(block, buffer=("write", 0), index_map=output_transformed_layout) n, h, w, c = s.get_loops(block) diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py index 450708f4b2a9..def86486dbe9 100755 --- a/python/tvm/topi/hexagon/utils.py +++ b/python/tvm/topi/hexagon/utils.py @@ -15,62 +15,38 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name - +"""Common hexagon specific utilities""" from tvm import te def n11c_1024c_2d(n, h, w, c): + """Return index map for n11c_1024 2d layout""" return [n, h, w, c // 1024, te.AXIS_SEPARATOR, c % 1024] def n11c_1024c_1d(n, h, w, c): + """Return index map for n11c_1024 1d layout""" return [n, h, w, c // 1024, c % 1024] def nhwc_8h2w32c2w_2d(n, h, w, c): + """Return index map for nhwc_8h2w32c2w 2d layout""" return [n, h // 8, w // 4, c // 32, te.AXIS_SEPARATOR, h % 8, (w % 4) // 2, c % 32, w % 2] def nhwc_8h2w32c2w_1d(n, h, w, c): + """Return index map for nhwc_8h2w32c2w 1d layout""" return [n, h // 8, w // 4, c // 32, h % 8, (w % 4) // 2, c % 32, w % 2] def get_layout_transform_fn(layout): + """Return index map function as per the layout string""" if layout == "nhwc-8h2w32c2w-2d": return nhwc_8h2w32c2w_2d if layout == "nhwc-8h2w32c2w-1d": return nhwc_8h2w32c2w_1d - elif layout == "n11c-1024c-2d": + if layout == "n11c-1024c-2d": return n11c_1024c_2d - elif layout == "n11c-1024c-1d": + if layout == "n11c-1024c-1d": return n11c_1024c_1d - else: - raise RuntimeError(f"Unexpected layout '{layout}'") - - -def apply_transform(s, block, block_index: int, buffer_type: str, layout: str): - """Apply transform layout on a buffer - - Parameters - ---------- - s: Schedule - block : BlockRV - The block that accesses the target buffer - buffer_index: int - The index of the buffer in block's read or write region - buffer_type : str - Type of the buffer index, "read" or "write" - layout : str - Layout of the buffer - """ - transform_fn = get_layout_transform_fn(layout) - if layout == "nhwc-8h2w32c2w-1d": - axis_separators = [4] - elif layout == "n11c-1024c-1d": - axis_separators = [2] - else: - raise RuntimeError(f"Unexpected layout '{layout}'") - - s.transform_layout(block, block_index, buffer_type, transform_fn) - if axis_separators: - s.set_axis_separator(block, block_index, buffer_type, axis_separators) + raise RuntimeError(f"Unexpected layout '{layout}'") diff --git a/tests/python/contrib/test_hexagon/test_add_subtract_multiply.py b/tests/python/contrib/test_hexagon/test_add_subtract_multiply.py index 7639c9bee497..1310e569a11e 100755 --- a/tests/python/contrib/test_hexagon/test_add_subtract_multiply.py +++ b/tests/python/contrib/test_hexagon/test_add_subtract_multiply.py @@ -20,7 +20,6 @@ import pytest import numpy as np -np.set_printoptions(threshold=np.inf) from tvm import te, topi import tvm.testing @@ -75,6 +74,7 @@ def transformed_expected_output_np(expected_output_np, output_layout): def hexagon_wrapper_allocation( device, layout, axis_separators, tensor_shape=None, data=None, transformed_data=None, dtype=None ): + """Input layout can either be nhwc-8h2w32c2w or nhwc""" if layout == "nhwc-8h2w32c2w": data_nd = allocate_hexagon_array( device, @@ -233,4 +233,4 @@ def test_transform( if __name__ == "__main__": - sys.exit(pytest.main(sys.argv)) + tvm.testing.main()