From bb2a8fa1e46cee5188a1a25e654f2212637ba157 Mon Sep 17 00:00:00 2001
From: Anirudh Sundar
Date: Wed, 18 May 2022 19:38:07 +0530
Subject: [PATCH] [Topi] [Hexagon] Conv2d slice op initial version

---
 python/tvm/topi/hexagon/slice_ops/__init__.py |   1 +
 python/tvm/topi/hexagon/slice_ops/conv2d.py   | 242 +++++++++++++
 python/tvm/topi/hexagon/utils.py              |  14 +
 .../test_hexagon/topi/test_conv2d_slice.py    | 339 ++++++++++++++++++
 4 files changed, 596 insertions(+)
 create mode 100644 python/tvm/topi/hexagon/slice_ops/conv2d.py
 create mode 100755 tests/python/contrib/test_hexagon/topi/test_conv2d_slice.py

diff --git a/python/tvm/topi/hexagon/slice_ops/__init__.py b/python/tvm/topi/hexagon/slice_ops/__init__.py
index 5b3ef530b0c07..3eafd78f5c83d 100644
--- a/python/tvm/topi/hexagon/slice_ops/__init__.py
+++ b/python/tvm/topi/hexagon/slice_ops/__init__.py
@@ -20,3 +20,4 @@
 from .avg_pool2d import avg_pool2d_compute, avg_pool2d_STIR_schedule
 from .add_subtract_multiply import *
 from .softmax_slice import *
+from .conv2d import *
diff --git a/python/tvm/topi/hexagon/slice_ops/conv2d.py b/python/tvm/topi/hexagon/slice_ops/conv2d.py
new file mode 100644
index 0000000000000..439fd80648f9d
--- /dev/null
+++ b/python/tvm/topi/hexagon/slice_ops/conv2d.py
@@ -0,0 +1,242 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=line-too-long
+
+"""Hexagon slice conv2d compute and schedule"""
+import typing
+
+import tvm
+from tvm import te
+
+from ..utils import get_layout_transform_fn
+
+
+def conv2d_compute(
+    activations: te.Tensor,
+    weights: te.Tensor,
+    out_shape: typing.Tuple,
+    stride: typing.Tuple,
+    dilation: typing.Tuple,
+    dtype: str,
+    output_name: str,
+    weights_width_reversed: bool = True,
+) -> te.Tensor:
+    """Compute for slice conv2d op for hexagon.
+
+    This op makes the following assumptions:
+    1. This op is written for a sliced convolution with 2d physical buffers
+    2. The input activations are assumed to be in NHWC layout and the filter in HWIO layout
+    3. Grouped convolutions are not supported; depthwise convolution will have a separate compute definition
+    4. To realize a grouped convolution, it is assumed that the op will be sliced according to the groups, with one call to this compute placed per group
+
+    Parameters
+    ----------
+    activations : te.Tensor
+        Input activations padded for inner dimension size
+    weights : te.Tensor
+        Weights without dilation
+    out_shape : typing.Tuple
+        The logical output shape without considering input padding
+    stride : typing.Tuple
+        The (height, width) stride of the convolution
+    dilation : typing.Tuple
+        The (height, width) dilation of the convolution
+    dtype : str
+        The data type of the output
+    output_name : str
+        The name to be given to output.
+        This would become the block name for the corresponding STIR compute
+    weights_width_reversed : bool
+        The width axis of the weights is expected in reverse order if weights_width_reversed is True
+
+    Returns
+    -------
+    output : te.Tensor
+        Output of applying 2D convolution of Weights on Input
+    """
+
+    filt_shape = weights.shape
+
+    reduce_channel = tvm.te.reduce_axis((0, filt_shape[2]), name="reduce_channel")
+    reduce_height = tvm.te.reduce_axis((0, filt_shape[0]), name="reduce_height")
+    reduce_width = tvm.te.reduce_axis((0, filt_shape[1]), name="reduce_width")
+    stride_height, stride_width = stride
+    dilation_height, dilation_width = dilation
+
+    if weights_width_reversed:
+        weights_width_var = filt_shape[1] - reduce_width - 1
+    else:
+        weights_width_var = reduce_width
+
+    output = tvm.te.compute(
+        out_shape,
+        lambda n, h, w, c: tvm.te.sum(
+            (
+                activations[
+                    n,
+                    h * stride_height + reduce_height * dilation_height,
+                    w * stride_width + reduce_width * dilation_width,
+                    reduce_channel,
+                ]
+                * weights[reduce_height, weights_width_var, reduce_channel, c]
+            ).astype(dtype),
+            axis=[reduce_channel, reduce_height, reduce_width],
+        ),
+        name=output_name,
+    )
+    return output
+
+
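As a usage sketch (not part of the patch), the compute above can be driven directly from TE placeholders. The shapes are borrowed from the TVMScript example in the STIR schedule docstring further down; out_shape is the logical output of the unpadded 18x10 input, (18 - 3 + 1, 10 - 3 + 1) = (16, 8).

    import tvm
    from tvm import te
    from tvm.topi.hexagon.slice_ops.conv2d import conv2d_compute

    # Activations with height padded to a multiple of 8 and width to a multiple of 4;
    # weights in HWIO layout
    activations = te.placeholder((1, 24, 12, 32), name="InputTensor", dtype="float16")
    weights = te.placeholder((3, 3, 32, 32), name="Weights", dtype="float16")

    output = conv2d_compute(
        activations,
        weights,
        out_shape=(1, 16, 8, 32),
        stride=(1, 1),
        dilation=(1, 1),
        dtype="float16",
        output_name="output",
    )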
+def conv2d_te_schedule(
+    out: te.Tensor,
+    ins: typing.List[te.Tensor],
+    transform_activation_layout: str,
+    transform_weights_layout: str,
+    transform_output_layout: str,
+) -> te.Schedule:
+    """TE Schedule for the sliced conv2d op
+
+    This schedule makes the following assumptions:
+    1. There is only one output tensor
+    2. The activations, weights, and output have layouts defined by the three layout-transform arguments
+    3. All transformation functions are expected to be a bijection for now
+
+    Parameters
+    ----------
+    out : te.Tensor
+        The output tensor returned by a call to conv2d_compute
+    ins : typing.List[te.Tensor]
+        The list of 2 Tensors which would be the input activations and weights
+    transform_activation_layout : str
+        String representing the activations layout as defined in get_layout_transform_fn
+    transform_weights_layout : str
+        String representing the weights layout as defined in get_layout_transform_fn
+    transform_output_layout: str
+        String representing the output layout as defined in get_layout_transform_fn
+
+    Returns
+    -------
+    sch : te.Schedule
+        The TE schedule for slice conv2d
+    """
+    activations, weights = ins
+    output = out
+    sch = tvm.te.create_schedule(output.op)
+    reduce_channel, reduce_height, reduce_width = sch[output].op.reduce_axis
+    sch[activations].transform_layout(get_layout_transform_fn(transform_activation_layout))
+    sch[weights].transform_layout(get_layout_transform_fn(transform_weights_layout))
+    transformed_axis = sch[output].transform_layout(
+        get_layout_transform_fn(transform_output_layout)
+    )
+    fused_out_axis = sch[output].fuse(transformed_axis[-1], transformed_axis[-2])
+    sch[output].reorder(
+        *[*transformed_axis[:-2], reduce_height, reduce_width, reduce_channel, fused_out_axis]
+    )
+    # The below code doesn't work yet as vectorization across the 2D boundary is not yet supported
+    # sch[output].vectorize(fused_out_axis)
+    return sch
+
+
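Continuing the sketch above, the TE schedule can be inspected by lowering it. Whether it lowers end to end depends on the transform_layout support noted in the comment, so treat this as illustrative only.

    from tvm.topi.hexagon.slice_ops.conv2d import conv2d_te_schedule

    sch = conv2d_te_schedule(
        output,
        [activations, weights],
        "nhwc-8h2w32c2w-2d",
        "iohw-16i32o2i-1d",
        "nhwc-8h2w32c2w-2d",
    )
    # Print the lowered TIR; vectorization stays disabled, as noted in the schedule
    print(tvm.lower(sch, [activations, weights, output], simple_mode=True))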
+def conv2d_schedule(
+    outs: te.Tensor,
+    ins: typing.List[te.Tensor],
+    transform_activation_layout: str,
+    transform_weights_layout: str,
+    transform_output_layout: str,
+    output_name: str,
+) -> tvm.tir.Schedule:
+    """STIR schedule definition for the compute defined above by conv2d_compute.
+
+    - For reference, the auto-generated prim_func below shows the compute before any schedule primitives are applied
+    - The TVMScript corresponds to a conv2d with padded input dimensions and a stride of 1x1
+
+    # from tvm.script import tir as T
+    @T.prim_func
+    def func(InputTensor: T.Buffer[(1, 24, 12, 32), "float16"], Weights: T.Buffer[(3, 3, 32, 32), "float16"], compute: T.Buffer[(1, 16, 8, 32), "float16"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 16, 8, 32, 32, 3, 3):
+            with T.block("compute"):
+                n, h, w, c, rc, rh, rw = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
+                T.reads(InputTensor[n, h + rh, w + rw, rc], Weights[rh, rw, rc, c])
+                T.writes(compute[n, h, w, c])
+                with T.init():
+                    compute[n, h, w, c] = T.float16(0)
+                compute[n, h, w, c] = compute[n, h, w, c] + InputTensor[n, h + rh, w + rw, rc] * Weights[rh, rw, rc, c]
+
+    Parameters
+    ----------
+    outs : te.Tensor
+        The output Tensor as returned by a call to conv2d_compute
+    ins : typing.List[te.Tensor]
+        This is a list of 2 tensors - Input activations and Weights
+    transform_activation_layout : str
+        String representing the activations layout as defined in get_layout_transform_fn
+    transform_weights_layout : str
+        String representing the weights layout as defined in get_layout_transform_fn
+    transform_output_layout: str
+        String representing the output layout as defined in get_layout_transform_fn
+    output_name : str
+        The name that was given to the output compute and which can be used to get the block name
+
+    Returns
+    -------
+    sch : tvm.tir.Schedule
+        The STIR schedule for slice conv2d compute
+    """
+
+    assert len(ins) == 2, "This schedule expects only 2 inputs - Activations and Weights"
+    source_expr = ins + [outs]
+    prim_func = tvm.te.create_prim_func(source_expr)
+    sch = tvm.tir.Schedule(prim_func)
+
+    compute = sch.get_block(output_name)
+    # Apply layout_transform for activation
+    sch.transform_layout(compute, ins[0].name, get_layout_transform_fn(transform_activation_layout))
+
+    # Apply layout_transform for weights
+    sch.transform_layout(compute, ins[1].name, get_layout_transform_fn(transform_weights_layout))
+
+    # Apply layout_transform for output
+    sch.transform_layout(compute, outs.name, get_layout_transform_fn(transform_output_layout))
+
+    batch, height, width, channel, reduce_channel, reduce_height, reduce_width = sch.get_loops(
+        compute
+    )  # This still returns the original 7d loop
+    h_outer, h_inner = sch.split(height, [None, 8])
+    w_outer, w_inner = sch.split(width, [None, 4])
+    w_inner_outer, w_inner_inner = sch.split(w_inner, [2, 2])
+    c_outer, c_inner = sch.split(channel, [None, 32])
+    sch.reorder(
+        batch,
+        h_outer,
+        w_outer,
+        c_outer,
+        h_inner,
+        w_inner_outer,
+        reduce_height,
+        reduce_width,
+        reduce_channel,
+        c_inner,
+        w_inner_inner,
+    )
+    sch.decompose_reduction(compute, reduce_height)
+    # ci_wii = sch.fuse(c_inner, w_inner_inner)
+    # sch.vectorize(ci_wii)  # Doesn't work yet as vectorization across the 2d boundary is unsupported
+    return sch
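The split factors in conv2d_schedule mirror the nhwc-8h2w32c2w-2d output layout. A worked accounting for the (1, 16, 8, 32) output of the docstring example (the 128-byte HVX vector width is an assumption made here, not stated in the patch):

    # h: 16 -> (h_outer=2, h_inner=8)                          the "8h" block
    # w:  8 -> (w_outer=2, w_inner_outer=2, w_inner_inner=2)   the two "2w" blocks
    # c: 32 -> (c_outer=1, c_inner=32)                         the "32c" block
    # Fusing c_inner with w_inner_inner, as the commented-out lines intend, would
    # give a 32 * 2 = 64-element float16 axis, i.e. 128 bytes, one full HVX vector.
    assert 32 * 2 * 2 == 128  # c_inner * w_inner_inner * sizeof(float16) in bytes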
diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py
index 3efc48c4d04fa..170c2ce88d377 100644
--- a/python/tvm/topi/hexagon/utils.py
+++ b/python/tvm/topi/hexagon/utils.py
@@ -62,6 +62,18 @@ def nc_512c_2d(n, c):
     return [n, c // 512, te.AXIS_SEPARATOR, c % 512]
 
 
+def iohw_16i32o2i_1d(height, width, in_channel, out_channel):
+    return [
+        in_channel // 32,
+        out_channel // 32,
+        height,
+        width,
+        (in_channel % 32) // 2,
+        out_channel % 32,
+        in_channel % 2,
+    ]
+
+
 def get_layout_transform_fn(layout):
     """Return index map function as per the layout string"""
     if layout == "nhwc-8h2w32c2w-2d":
@@ -80,4 +92,6 @@ def get_layout_transform_fn(layout):
         return nc_512c_2d
     if layout == "nc-512c-1d":
         return nc_512c_1d
+    if layout == "iohw-16i32o2i-1d":
+        return iohw_16i32o2i_1d
     raise RuntimeError(f"Unexpected layout '{layout}'")
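A worked example of the new index map (index values chosen arbitrarily): a logical HWIO index (h, w, i, o) maps to [i // 32, o // 32, h, w, (i % 32) // 2, o % 32, i % 2], i.e. 32-channel tiles with the input channels split into 16 pairs.

    from tvm.topi.hexagon.utils import get_layout_transform_fn

    index_map = get_layout_transform_fn("iohw-16i32o2i-1d")
    # (h=1, w=2, i=37, o=45) -> [37 // 32, 45 // 32, 1, 2, (37 % 32) // 2, 45 % 32, 37 % 2]
    assert index_map(1, 2, 37, 45) == [1, 1, 1, 2, 2, 13, 1]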
diff --git a/tests/python/contrib/test_hexagon/topi/test_conv2d_slice.py b/tests/python/contrib/test_hexagon/topi/test_conv2d_slice.py
new file mode 100755
index 0000000000000..a03c35cb9e78f
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/topi/test_conv2d_slice.py
@@ -0,0 +1,339 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=line-too-long, redefined-outer-name
+
+"""Test conv2d slice op for hexagon"""
+
+import numpy as np
+
+import tvm
+import tvm.testing
+from tvm.topi.hexagon.slice_ops.conv2d import conv2d_compute, conv2d_schedule
+from tvm.topi.testing import conv2d_nhwc_python
+
+from ..infrastructure import allocate_hexagon_array, transform_numpy
+
+input_layout = tvm.testing.parameter(
+    "nhwc-8h2w32c2w-2d",
+)
+
+output_layout = tvm.testing.parameter(
+    "nhwc-8h2w32c2w-2d",
+)
+
+weights_layout = tvm.testing.parameter("iohw-16i32o2i-1d")
+
+
+@tvm.testing.fixture
+def input_np(in_shape, dtype):
+    return np.random.uniform(size=in_shape).astype(dtype)
+
+
+@tvm.testing.fixture
+def weights_np(filt_shape, dtype):
+    return (np.random.uniform(size=filt_shape)).astype(dtype)
+
+
+@tvm.testing.fixture
+def dilated_filt_shape(filt_shape, dilation):
+    """Compute the dilated filter shape when dilation > 1"""
+    filt_height, filt_width, in_channel, out_channel = filt_shape
+    dilation_height, dilation_width = dilation
+    if dilation_height == 1 and dilation_width == 1:
+        return filt_shape
+    dilated_height, dilated_width = (
+        dilation_height * (filt_height - 1) + 1,
+        dilation_width * (filt_width - 1) + 1,
+    )
+    return dilated_height, dilated_width, in_channel, out_channel
+
+
+@tvm.testing.fixture
+def dilated_weights_np(weights_np, dilation, dilated_filt_shape):
+    """Get dilated weights from original weights for testing"""
+    filt_height, filt_width, in_channels, out_channels = weights_np.shape
+    dilation_height, dilation_width = dilation
+    if dilation_height == 1 and dilation_width == 1:
+        return weights_np
+    dilated_height, dilated_width = dilated_filt_shape[0], dilated_filt_shape[1]
+    dilated_weights = np.zeros(dilated_filt_shape, dtype="float16")
+    for in_channel in range(in_channels):
+        for out_channel in range(out_channels):
+            for dilation_i, height_i in zip(
+                range(0, dilated_height, dilation_height), range(filt_height)
+            ):
+                for dilation_j, width_j in zip(
+                    range(0, dilated_width, dilation_width), range(filt_width)
+                ):
+                    dilated_weights[dilation_i, dilation_j, in_channel, out_channel] = weights_np[
+                        height_i, width_j, in_channel, out_channel
+                    ]
+
+    return dilated_weights
+
+
+@tvm.testing.fixture
+def input_np_padded(input_np, in_shape, padded_in_shape):
+    pad_height = padded_in_shape[1] - in_shape[1]
+    pad_width = padded_in_shape[2] - in_shape[2]
+    pad_channel = padded_in_shape[3] - in_shape[3]
+    input_padded = np.pad(
+        input_np, ((0, 0), (0, pad_height), (0, pad_width), (0, pad_channel)), "constant"
+    )
+    return input_padded
+
+
+@tvm.testing.fixture
+def padded_filt_shape(filt_shape):
+    filt_height, filt_width, in_channels, out_channels = filt_shape
+    in_channels = ((in_channels + 31) // 32) * 32
+    out_channels = ((out_channels + 31) // 32) * 32
+    return filt_height, filt_width, in_channels, out_channels
+
+
+@tvm.testing.fixture
+def weights_np_padded(weights_np, filt_shape, padded_filt_shape):
+    pad_in_channels = padded_filt_shape[2] - filt_shape[2]
+    pad_out_channels = padded_filt_shape[3] - filt_shape[3]
+    filt_padded = np.pad(weights_np, ((0, 0), (0, 0), (0, pad_in_channels), (0, pad_out_channels)))
+    return filt_padded
+
+
+@tvm.testing.fixture
+def weights_np_transformed(weights_np_padded):
+    height, width, in_channel, out_channel = weights_np_padded.shape
+    weights_np_reverse_width = weights_np_padded[:, ::-1, :, :]
+    transformed_weights_np = weights_np_reverse_width.reshape(
+        [height, width, in_channel // 32, 16, 2, out_channel // 32, 32]
+    ).transpose(2, 5, 0, 1, 3, 6, 4)
+    return transformed_weights_np
+
+
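As a quick numpy sanity check of the fixture above (arbitrary element; padded 3x3x32x32 weights keep the channel splits trivial), the reshape and transpose amount to a pure relayout of the width-reversed weights, matching the iohw-16i32o2i index map plus the width reversal expected by conv2d_compute:

    import numpy as np

    w = np.random.uniform(size=(3, 3, 32, 32)).astype("float16")
    t = w[:, ::-1, :, :].reshape([3, 3, 1, 16, 2, 1, 32]).transpose(2, 5, 0, 1, 3, 6, 4)
    # An HWIO element (h, wd, i, o) lands at
    # [i // 32, o // 32, h, 2 - wd, (i % 32) // 2, o % 32, i % 2]
    h, wd, i, o = 1, 0, 7, 21
    assert t[i // 32, o // 32, h, 2 - wd, (i % 32) // 2, o % 32, i % 2] == w[h, wd, i, o]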
+def generate_test_config(test_params):
+    """Utility function to generate test config with meaningful ids"""
+    test_config = {}
+
+    dims = lambda vals: "x".join(map(str, vals))
+
+    for param in test_params:
+        in_shape, filt_shape, stride, dilation = param
+        test_name = f"nhwc{dims(in_shape)}-hwio{dims(filt_shape)}-stride{dims(stride)}-dilation{dims(dilation)}"
+        test_config[test_name] = param
+
+    return test_config
+
+
+class TestConv2dSlice:
+    """Test class that defines the conv2d slice test"""
+
+    test_params = [
+        [
+            (1, 10, 6, 32),
+            (3, 3, 32, 32),
+            (1, 1),
+            (1, 1),
+        ],
+        [
+            (1, 18, 10, 32),
+            (3, 3, 32, 32),
+            (1, 1),
+            (1, 1),
+        ],
+        [
+            (1, 10, 6, 64),
+            (3, 3, 64, 64),
+            (1, 1),
+            (1, 1),
+        ],
+        [
+            (1, 12, 8, 4),
+            (3, 3, 4, 32),
+            (1, 1),
+            (2, 2),
+        ],
+        [
+            (1, 12, 8, 32),
+            (5, 5, 32, 32),
+            (1, 1),
+            (1, 1),
+        ],
+        [
+            (1, 16, 12, 32),
+            (5, 5, 32, 32),
+            (1, 1),
+            (2, 2),
+        ],
+        [
+            (1, 13, 9, 32),
+            (6, 6, 32, 32),
+            (1, 1),
+            (1, 1),
+        ],
+        [
+            (1, 18, 10, 32),
+            (3, 3, 32, 32),
+            (2, 2),
+            (1, 1),
+        ],
+        [
+            (1, 20, 12, 32),
+            (5, 5, 32, 32),
+            (2, 2),
+            (1, 1),
+        ],
+        [
+            (1, 22, 14, 32),
+            (7, 7, 32, 32),
+            (2, 2),
+            (1, 1),
+        ],
+        [
+            (1, 28, 20, 32),
+            (7, 7, 32, 32),
+            (2, 2),
+            (2, 2),
+        ],
+        [
+            (1, 10, 4, 4),
+            (3, 1, 4, 32),
+            (1, 1),
+            (1, 1),
+        ],
+        [
+            (1, 18, 8, 4),
+            (3, 1, 4, 32),
+            (2, 2),
+            (1, 1),
+        ],
+        [
+            (1, 20, 8, 4),
+            (3, 1, 4, 32),
+            (2, 2),
+            (2, 2),
+        ],
+    ]
+
+    test_config = generate_test_config(test_params)
+
+    in_shape, filt_shape, stride, dilation = tvm.testing.parameters(
+        *test_config.values(), ids=test_config.keys()
+    )
+    dtype = tvm.testing.parameter("float16")
+    working_scope = tvm.testing.parameter("global.vtcm")
+
+    @tvm.testing.fixture
+    def padded_in_shape(self, in_shape):
+        in_batch, in_height, in_width, in_channel = in_shape
+        in_height = ((in_height + 7) // 8) * 8
+        in_width = ((in_width + 3) // 4) * 4
+        in_channel = ((in_channel + 31) // 32) * 32
+        return in_batch, in_height, in_width, in_channel
+
+    @tvm.testing.fixture
+    def out_shape(self, in_shape, dilated_filt_shape, stride):
+        in_batch, in_height, in_width, _ = in_shape
+        filt_height, filt_width, _, num_filt = dilated_filt_shape
+        out_height = (in_height - filt_height) // stride[0] + 1
+        out_width = (in_width - filt_width) // stride[1] + 1
+        out_channel = num_filt
+        return in_batch, out_height, out_width, out_channel
+
+    @tvm.testing.fixture
+    def expected_output_np(self, input_np, dilated_weights_np, stride):
+        ref_np = conv2d_nhwc_python(
+            input_np.astype("float32"), dilated_weights_np.astype("float32"), stride, padding=0
+        ).astype("float16")
+        return ref_np
+
+    @tvm.testing.requires_hexagon
+    def test_conv2d(
+        self,
+        padded_in_shape,
+        padded_filt_shape,
+        stride,
+        dilation,
+        dtype,
+        out_shape,
+        input_layout,
+        weights_layout,
+        output_layout,
+        input_np_padded,
+        weights_np_transformed,
+        expected_output_np,
+        target,
+        working_scope,
+        hexagon_session,
+    ):
+        """Main test function that tests the conv2d slice op"""
+        input_tensor = tvm.te.placeholder(padded_in_shape, name="InputTensor", dtype=dtype)
+        weights = tvm.te.placeholder(padded_filt_shape, name="Weights", dtype=dtype)
+        output_name = "output"
+
+        output_tensor = conv2d_compute(
+            input_tensor, weights, out_shape, stride, dilation, dtype, output_name
+        )
+
+        target_hexagon = tvm.target.hexagon("v69")
+        target = tvm.target.Target(target_hexagon, host=target_hexagon)
+
+        tir_schedule = conv2d_schedule(
+            output_tensor,
+            [input_tensor, weights],
+            input_layout,
+            weights_layout,
+            output_layout,
+            output_name,
+        )
+
+        func_name = f"fconv2d_{dtype}"
+        with tvm.transform.PassContext(opt_level=3):
+            runtime_module = tvm.build(
+                tir_schedule.mod,
+                target=target,
+                name=func_name,
+            )
+
+        input_np_transformed = transform_numpy(input_np_padded, "nhwc", input_layout)
+        output_np_transformed = transform_numpy(expected_output_np, "nhwc", output_layout)
+
+        input_arr = allocate_hexagon_array(
+            hexagon_session.device,
+            data=input_np_transformed,
+            axis_separators=[4],
+            mem_scope=working_scope,
+        )
+
+        weights_arr = allocate_hexagon_array(
+            hexagon_session.device, data=weights_np_transformed, mem_scope=working_scope
+        )
+
+        output_arr = allocate_hexagon_array(
+            hexagon_session.device,
+            tensor_shape=output_np_transformed.shape,
+            dtype=output_np_transformed.dtype,
+            axis_separators=[4],
+            mem_scope=working_scope,
+        )
+
+        mod = hexagon_session.load_module(runtime_module)
+        mod(input_arr, weights_arr, output_arr)
+        output_np = output_arr.numpy()
+        np.testing.assert_allclose(output_np, output_np_transformed, atol=1.0, rtol=0.05)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
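For orientation, the fixture arithmetic for the first test config above (nhwc1x10x6x32-hwio3x3x32x32-stride1x1-dilation1x1) works out as follows:

    # padded_in_shape rounds height up to a multiple of 8, width to 4, channel to 32
    assert ((10 + 7) // 8) * 8 == 16 and ((6 + 3) // 4) * 4 == 8  # (1, 10, 6, 32) -> (1, 16, 8, 32)
    # out_shape uses the unpadded input and the (possibly dilated) filter shape
    assert (10 - 3) // 1 + 1 == 8 and (6 - 3) // 1 + 1 == 4  # -> (1, 8, 4, 32)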