From bc75487032aee5d0fa2bb1d673f9dcd8c7e43810 Mon Sep 17 00:00:00 2001 From: Tasmia Rahman <89925728+trahman-quic@users.noreply.github.com> Date: Tue, 21 Jun 2022 13:12:08 -0500 Subject: [PATCH] [HEXAGON] Slice ops added - add, subtract, multiply (#11529) * [UPSTREAM][HEXAGON] Slice ops added - add, subtract, multiply * Change to v68 * Change transform_numpy function call * Do not disbale pylint errors and fix them * Fix variable names * Move the test file to topi * Resolve conflict * Modify init --- python/tvm/topi/hexagon/slice_ops/__init__.py | 3 +- .../slice_ops/add_subtract_multiply.py | 87 +++++++ .../topi/test_add_subtract_multiply.py | 229 ++++++++++++++++++ 3 files changed, 317 insertions(+), 2 deletions(-) create mode 100755 python/tvm/topi/hexagon/slice_ops/add_subtract_multiply.py create mode 100755 tests/python/contrib/test_hexagon/topi/test_add_subtract_multiply.py diff --git a/python/tvm/topi/hexagon/slice_ops/__init__.py b/python/tvm/topi/hexagon/slice_ops/__init__.py index b52d410676af..70531c629e4c 100644 --- a/python/tvm/topi/hexagon/slice_ops/__init__.py +++ b/python/tvm/topi/hexagon/slice_ops/__init__.py @@ -17,6 +17,5 @@ """ Computes and Schedules for Hexagon slice ops. """ -# pylint: disable=wildcard-import - from .avg_pool2d import avg_pool2d_compute, avg_pool2d_STIR_schedule +from .add_subtract_multiply import * diff --git a/python/tvm/topi/hexagon/slice_ops/add_subtract_multiply.py b/python/tvm/topi/hexagon/slice_ops/add_subtract_multiply.py new file mode 100755 index 000000000000..86b6adb997cb --- /dev/null +++ b/python/tvm/topi/hexagon/slice_ops/add_subtract_multiply.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name + +"""Compute and schedule for add, multiply, subtract slice op + +Please note the following assumptions made by the implementation: + +1) The inputs will be multiple of crouton layout except for the axis that needs broadcasting.""" + +from tvm import te +from tvm import tir +from tvm import topi +from ..utils import get_layout_transform_fn + + +def add_broadcast_compute(input_a, input_b): + """Call the add op from topi""" + return topi.add(input_a, input_b) + + +def subtract_broadcast_compute(input_a, input_b): + """Call the subtract op from topi""" + return topi.subtract(input_a, input_b) + + +def multiply_broadcast_compute(input_a, input_b): + """Call the multiply op from topi""" + return topi.multiply(input_a, input_b) + + +def tir_broadcast_schedule( + out_m, + input_a, + input_b, + output_layout: str, + input_a_layout: str, + input_b_layout: str, + op_name: str, +): + """Schedule for input and output layout nhwc-8h2w32c2w-2d considering broadcast""" + func = te.create_prim_func([input_a, input_b, out_m]) + + s = tir.Schedule(func) + + block_dict = {"add": "T_add", "subtract": "T_subtract", "multiply": "T_multiply"} + + block = s.get_block(block_dict[op_name]) + + if input_a_layout == "nhwc-8h2w32c2w-2d": + input_a_transformed_layout = get_layout_transform_fn(input_a_layout) + s.transform_layout(block, buffer=("read", 0), index_map=input_a_transformed_layout) + + if input_b_layout == "nhwc-8h2w32c2w-2d": + input_b_transformed_layout = get_layout_transform_fn(input_b_layout) + s.transform_layout(block, buffer=("read", 1), index_map=input_b_transformed_layout) + + output_transformed_layout = get_layout_transform_fn(output_layout) + s.transform_layout(block, buffer=("write", 0), index_map=output_transformed_layout) + + n, h, w, c = s.get_loops(block) + + h_o, h_i = s.split(h, [None, 8]) + w_o, w_i = s.split(w, [None, 4]) + c_o, c_i = s.split(c, [None, 32]) + wio, wii = s.split(w_i, [None, 2]) + + s.reorder(n, h_o, w_o, c_o, h_i, wio, c_i, wii) + + fused = s.fuse(c_i, wii) + s.vectorize(fused) + + return s diff --git a/tests/python/contrib/test_hexagon/topi/test_add_subtract_multiply.py b/tests/python/contrib/test_hexagon/topi/test_add_subtract_multiply.py new file mode 100755 index 000000000000..fa2d9797a882 --- /dev/null +++ b/tests/python/contrib/test_hexagon/topi/test_add_subtract_multiply.py @@ -0,0 +1,229 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import pytest +import numpy as np + +from tvm import te, topi + +import tvm.testing +from tvm.topi import testing +from tvm.contrib.hexagon.build import HexagonLauncher +import tvm.topi.hexagon.slice_ops as sl +from ..infrastructure import allocate_hexagon_array, transform_numpy + + +@tvm.testing.fixture +def expected_output_np(input_np_A, input_np_B, op_name): + if op_name == "add": + out_ref = np.add(input_np_A, input_np_B) + elif op_name == "subtract": + out_ref = np.subtract(input_np_A, input_np_B) + elif op_name == "multiply": + out_ref = np.multiply(input_np_A, input_np_B) + return out_ref + + +@tvm.testing.fixture +def input_np_A(input_shape_A, dtype): + return np.random.random(input_shape_A).astype(dtype) + + +@tvm.testing.fixture +def input_np_B(input_shape_B, dtype): + return np.random.random(input_shape_B).astype(dtype) + + +@tvm.testing.fixture +def transformed_input_np_A(input_np_A, input_A_layout): + return transform_numpy(input_np_A, "nhwc", input_A_layout) + + +@tvm.testing.fixture +def transformed_input_np_B(input_np_B, input_B_layout): + return transform_numpy(input_np_B, "nhwc", input_B_layout) + + +@tvm.testing.fixture +def transformed_expected_output_np(expected_output_np, output_layout): + return transform_numpy(expected_output_np, "nhwc", output_layout) + + +def hexagon_wrapper_allocation( + device, layout, axis_separators, tensor_shape=None, data=None, transformed_data=None, dtype=None +): + """Input layout can either be nhwc-8h2w32c2w-2d or nhwc""" + if layout == "nhwc-8h2w32c2w-2d": + data_nd = allocate_hexagon_array( + device, + tensor_shape=tensor_shape, + data=transformed_data, + dtype=dtype, + axis_separators=axis_separators, + mem_scope="global.vtcm", + ) + elif layout == "nhwc": + data_nd = allocate_hexagon_array( + device, + data=data, + ) + return data_nd + + +class TestAddSubtractMultiplyBroadcast2d: + ( + input_shape_A, + input_shape_B, + input_A_layout, + input_B_layout, + output_layout, + dtype, + ) = tvm.testing.parameters( + # no broadcast needed - short input + ( + [1, 8, 4, 32], + [1, 8, 4, 32], + "nhwc-8h2w32c2w-2d", + "nhwc-8h2w32c2w-2d", + "nhwc-8h2w32c2w-2d", + "float16", + ), + # no broadcast needed - large input + ( + [1, 56, 64, 128], + [1, 56, 64, 128], + "nhwc-8h2w32c2w-2d", + "nhwc-8h2w32c2w-2d", + "nhwc-8h2w32c2w-2d", + "float16", + ), + # one input needs broadcast + ( + [1, 56, 64, 128], + [1, 1, 64, 1], + "nhwc-8h2w32c2w-2d", + "nhwc", + "nhwc-8h2w32c2w-2d", + "float16", + ), + # Both input needs broadcast + ( + [1, 56, 1, 128], + [1, 1, 64, 1], + "nhwc", + "nhwc", + "nhwc-8h2w32c2w-2d", + "float16", + ), + # One axis in one input needs broadcast + ( + [1, 56, 20, 128], + [1, 56, 20, 1], + "nhwc-8h2w32c2w-2d", + "nhwc", + "nhwc-8h2w32c2w-2d", + "float16", + ), + ) + + op_name = tvm.testing.parameter("add", "subtract", "multiply") + + @tvm.testing.requires_hexagon + def test_transform( + self, + dtype, + input_shape_A, + input_shape_B, + input_np_A, + input_np_B, + transformed_input_np_A, + transformed_input_np_B, + expected_output_np, + transformed_expected_output_np, + hexagon_session, + output_layout, + input_A_layout, + input_B_layout, + op_name, + ): + target_hexagon = tvm.target.hexagon("v68") + A = te.placeholder(input_shape_A, name="A", dtype=dtype) + B = te.placeholder(input_shape_B, name="B", dtype=dtype) + if op_name == "add": + M = sl.add_broadcast_compute(A, B) + elif op_name == "subtract": + M = sl.subtract_broadcast_compute(A, B) + elif op_name == "multiply": + M = sl.multiply_broadcast_compute(A, B) + + tir_schedule = sl.tir_broadcast_schedule( + M, A, B, output_layout, input_A_layout, input_B_layout, op_name + ) + sch = tir_schedule.mod + + input_axis_separator = [4] + if output_layout == "nhwc-8h2w32c2w-2d": + output_axis_separator = [4] + else: + raise RuntimeError(f"Unexpected layout '{output_layout}'") + + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_assert": True}): + func = tvm.build( + sch, + [A, B, M], + tvm.target.Target(target_hexagon, host=target_hexagon), + name="slice_op_with_transform", + ) + + output_shape = expected_output_np.shape + + A_data_nd = hexagon_wrapper_allocation( + hexagon_session.device, + layout=input_A_layout, + data=input_np_A, + transformed_data=transformed_input_np_A, + axis_separators=input_axis_separator, + ) + B_data_nd = hexagon_wrapper_allocation( + hexagon_session.device, + layout=input_B_layout, + data=input_np_B, + transformed_data=transformed_input_np_B, + axis_separators=input_axis_separator, + ) + M_data_nd = hexagon_wrapper_allocation( + hexagon_session.device, + layout=output_layout, + tensor_shape=transformed_expected_output_np.shape, + axis_separators=output_axis_separator, + dtype=dtype, + ) + + mod = hexagon_session.load_module(func) + mod(A_data_nd, B_data_nd, M_data_nd) + + b, h, w, c = output_shape + # convert nd to np and reshape to fixed chunk size layout + if output_layout == "nhwc-8h2w32c2w-2d": + M_data_np = M_data_nd.numpy().reshape([b, h // 8, w // 4, c // 32, 8, 2, 32, 2]) + + np.testing.assert_allclose(transformed_expected_output_np, M_data_np, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + tvm.testing.main()