diff --git a/python/tvm/topi/hexagon/qnn/__init__.py b/python/tvm/topi/hexagon/qnn/__init__.py
index e27e3793d565..25d1e6d1854d 100644
--- a/python/tvm/topi/hexagon/qnn/__init__.py
+++ b/python/tvm/topi/hexagon/qnn/__init__.py
@@ -18,3 +18,8 @@
 """ Computes and schedules for Hexagon quantized ops """
 
 from .avg_pool2d import qnn_avg_pool2d_compute, qnn_avg_pool2d_schedule
+
+from .dequantize import (
+    dequantize_compute,
+    dequantize_schedule,
+)
diff --git a/python/tvm/topi/hexagon/qnn/dequantize.py b/python/tvm/topi/hexagon/qnn/dequantize.py
new file mode 100644
index 000000000000..3e1466e88b38
--- /dev/null
+++ b/python/tvm/topi/hexagon/qnn/dequantize.py
@@ -0,0 +1,118 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+
+""" Hexagon qnn.dequantize slice op compute and schedule"""
+
+from tvm import te
+from tvm import tir
+from ..utils import get_layout_transform_fn
+
+
+def dequantize_compute(tensor_A, scale_A, zero_point_A):
+    """Compute dequantization: scale_A * (tensor_A - zero_point_A), cast to float32.
+
+    Parameters
+    ----------
+    tensor_A : te.Tensor
+        Quantized input tensor; the result has the same shape.
+    scale_A
+        Scale multiplied with the zero-point-adjusted input.
+    zero_point_A
+        Zero point subtracted from every element of tensor_A.
+
+    Returns
+    -------
+    te.Tensor
+        Tensor named "dequantize" holding the float32 result.
+    """
+    return te.compute(
+        tensor_A.shape,
+        lambda *indices: (scale_A * (tensor_A[indices] - zero_point_A)).astype("float32"),
+        name="dequantize",
+    )
+
+
+def dequantize_stir_schedule_nhwc_8h8w32c(
+    _in,
+    _out,
+    in_layout,
+    out_layout,
+):
+    """Schedule for nhwc int8/uint8 to f32 : nhwc layout"""
+    func = te.create_prim_func([_in, _out])
+    sch = tir.Schedule(func, debug_mask="all")
+    block_name = "dequantize"
+    n, h, w, c = sch.get_loops(sch.get_block(block_name))
+    # Tile h by 4, w by 8 (further split as 2 x 4) and c by 32.
+    ho, hi = sch.split(h, [None, 4])
+    wo, wi = sch.split(w, [None, 8])
+    wio, wii = sch.split(wi, [None, 4])
+    co, ci = sch.split(c, [None, 32])
+    # NOTE(review): "A" presumably selects the input buffer by name, so the
+    # input placeholder has to be named "A" -- confirm against callers.
+    sch.transform_layout(block_name, "A", in_layout)
+    sch.transform_layout(block_name, block_name, out_layout)
+    sch.reorder(n, ho, wo, co, hi, wio, wii, ci)
+    # Vectorize the fused innermost 4 x 32 = 128 iterations.
+    wii_ci = sch.fuse(wii, ci)
+    sch.vectorize(wii_ci)
+    return sch
+
+
+def dequantize_stir_schedule_nc(
+    _in,
+    _out,
+    in_layout,
+    out_layout,
+):
+    """Schedule for nc int8/uint8 to f32 : nc layout"""
+    func = te.create_prim_func([_in, _out])
+    sch = tir.Schedule(func, debug_mask="all")
+    block_name = "dequantize"
+    _, c_orig = sch.get_loops(sch.get_block(block_name))
+    # Split c into chunks of 512 and vectorize each chunk.
+    _, c_inner = sch.split(c_orig, [None, 512])
+    sch.transform_layout(block_name, "A", in_layout)
+    sch.transform_layout(block_name, block_name, out_layout)
+    sch.vectorize(c_inner)
+    return sch
+
+
+def dequantize_schedule(_in, _output, in_layout_str, out_layout_str):
+    """Schedule for int8/uint8 to f32 : top level function
+
+    Dispatches on out_layout_str: "nhwc-4h2w32c2w-2d" or "nc-512c-2d". Any other
+    value that reaches the dispatch raises RuntimeError.
+    """
+    f32_layout_transform_func = get_layout_transform_fn(out_layout_str)
+    in_layout_transform_func = get_layout_transform_fn(in_layout_str)
+    if out_layout_str == "nhwc-4h2w32c2w-2d":
+        return dequantize_stir_schedule_nhwc_8h8w32c(
+            _in,
+            _output,
+            in_layout_transform_func,
+            f32_layout_transform_func,
+        )
+    if out_layout_str == "nc-512c-2d":
+        return dequantize_stir_schedule_nc(
+            _in,
+            _output,
+            in_layout_transform_func,
+            f32_layout_transform_func,
+        )
+    raise RuntimeError(f"Unexpected layout '{out_layout_str}'")
diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py
index c056408947b7..9939e5b6fbb7 100644
--- a/python/tvm/topi/hexagon/utils.py
+++ b/python/tvm/topi/hexagon/utils.py
@@ -100,6 +100,11 @@ def nc_2048_2d(n, c):
     return [n, c // 2048, te.AXIS_SEPARATOR, c % 2048]
 
 
+def nc_2048c_2d(n, c):
+    """Return index map for nc_2048c 2d layout"""
+    return [n, c // 2048, te.AXIS_SEPARATOR, c % 2048]
+
+
 def nhwc_8h8w32c_2d(n, h, w, c):
     """Return index map for nhwc_8h8w32c 2d layout"""
     return [n, h // 8, w // 8, c // 32, te.AXIS_SEPARATOR, h % 8, w % 8, c % 32]
@@ -156,6 +161,8 @@ def get_layout_transform_fn(layout):
         return nhwc_2048c_2d
     if layout == "nc-2048-2d":
         return nc_2048_2d
+    if layout == "nc-2048c-2d":
+        return nc_2048c_2d
     if layout == "nhwc-8h8w32c-2d":
         return nhwc_8h8w32c_2d
     if layout == "n11c-2048c-2d":
diff --git a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py
index 70e50fcb68d6..71960b649ea2 100644
--- a/tests/python/contrib/test_hexagon/infrastructure.py
+++ b/tests/python/contrib/test_hexagon/infrastructure.py
@@ -295,6 +295,8 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str):
             return arr_np.reshape([n, c // 1024, 1024])
         if new_layout in ["nc-512c-2d"]:
             return arr_np.reshape([n, c // 512, 512])
+        if new_layout in ["nc-2048c-2d"]:
+            return arr_np.reshape([n, c // 2048, 2048])
         raise RuntimeError(f"Unexpected new_layout '{new_layout}'")
 
     if current_layout == "nhw":
diff --git a/tests/python/contrib/test_hexagon/topi/test_dequantize_slice.py b/tests/python/contrib/test_hexagon/topi/test_dequantize_slice.py
new file mode 100644
index 000000000000..e9b3dd132692
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/topi/test_dequantize_slice.py
@@ -0,0 +1,125 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+
+""" Tests for Hexagon dequantize """
+import numpy as np
+
+import tvm
+import tvm.testing
+from tvm import te
+from tvm.topi.hexagon import qnn
+from ..infrastructure import allocate_hexagon_array, transform_numpy, quantize_np
+
+
+class TestDequantizeSlice2d:
+    """
+    For testing Dequantize Slice ops
+    """
+
+    input_shape, orig_layout, input_layout, output_layout, axis_sep, dtype = tvm.testing.parameters(
+        ((1, 16, 64, 128), "nhwc", "nhwc-8h8w32c-2d", "nhwc-4h2w32c2w-2d", [4], "int8"),
+        ((1, 16, 64, 128), "nhwc", "nhwc-8h8w32c-2d", "nhwc-4h2w32c2w-2d", [4], "uint8"),
+        ((1, 8, 8, 32), "nhwc", "nhwc-8h8w32c-2d", "nhwc-4h2w32c2w-2d", [4], "int8"),
+        ((1, 8, 8, 32), "nhwc", "nhwc-8h8w32c-2d", "nhwc-4h2w32c2w-2d", [4], "uint8"),
+        ((1, 2048), "nc", "nc-2048c-2d", "nc-512c-2d", [2], "int8"),
+        ((1, 2048), "nc", "nc-2048c-2d", "nc-512c-2d", [2], "uint8"),
+    )
+
+    working_scope = tvm.testing.parameter("global.vtcm")
+
+    @tvm.testing.fixture
+    def input_np(self, input_shape):
+        """Random float32 input of shape input_shape"""
+        arr_np = np.random.random(size=input_shape).astype("float32")
+        return arr_np
+
+    @tvm.testing.fixture
+    def transformed_input_np(self, input_np, orig_layout, input_layout, dtype):
+        """Quantized input in input_layout, plus its scale and zero point"""
+        quant_arr, scale, zero_point = quantize_np(input_np, dtype)
+        return [transform_numpy(quant_arr, orig_layout, input_layout), scale, zero_point]
+
+    @tvm.testing.fixture
+    def expected_output_np(self, input_np, dtype):
+        """Reference dequantized output computed with numpy"""
+        quant_np, scale, zero_point = quantize_np(input_np, dtype)
+        ref_np = (scale * (quant_np.astype("int32") - zero_point)).astype("float32")
+        return ref_np
+
+    @tvm.testing.fixture
+    def transformed_expected_output_np(self, expected_output_np, orig_layout, output_layout):
+        """Reference output transformed to output_layout"""
+        return transform_numpy(expected_output_np, orig_layout, output_layout)
+
+    @tvm.testing.requires_hexagon
+    def test_dequant_qnn(
+        self,
+        input_shape,
+        dtype,
+        input_layout,
+        output_layout,
+        transformed_input_np,
+        transformed_expected_output_np,
+        axis_sep,
+        hexagon_session,
+        working_scope,
+    ):
+        """
+        Top level testing function for dequantize
+        """
+        target_hexagon = tvm.target.hexagon("v69")
+        target = tvm.target.Target(target_hexagon, host=target_hexagon)
+
+        dequant_input = te.placeholder(input_shape, name="A", dtype=dtype)
+
+        in_data_np, in_scale, in_zero_pt = transformed_input_np
+
+        dequant_output = qnn.dequantize_compute(dequant_input, in_scale, in_zero_pt)
+
+        tir_s = qnn.dequantize_schedule(dequant_input, dequant_output, input_layout, output_layout)
+
+        input_data = allocate_hexagon_array(
+            hexagon_session.device,
+            data=in_data_np,
+            axis_separators=axis_sep,
+            mem_scope=working_scope,
+        )
+        output_data = allocate_hexagon_array(
+            hexagon_session.device,
+            tensor_shape=transformed_expected_output_np.shape,
+            dtype=transformed_expected_output_np.dtype,
+            axis_separators=axis_sep,
+            mem_scope=working_scope,
+        )
+        with tvm.transform.PassContext(opt_level=3):
+            tir_irm = tvm.lower(tir_s.mod, [dequant_input, dequant_output], name="dequantize")
+            runtime_module = tvm.build(tir_irm, target=target, name="dequantize")
+        mod = hexagon_session.load_module(runtime_module)
+
+        mod(input_data, output_data)
+        output_np = output_data.numpy()
+        tvm.testing.assert_allclose(
+            output_np,
+            transformed_expected_output_np,
+            1e-3,
+            1e-3,
+        )
+
+
+if __name__ == "__main__":
+    tvm.testing.main()