Skip to content

Commit

Permalink
dequantize op hexagon
Browse files Browse the repository at this point in the history
  • Loading branch information
aakaverm committed Sep 2, 2022
1 parent e814f79 commit 6214da7
Show file tree
Hide file tree
Showing 5 changed files with 229 additions and 0 deletions.
5 changes: 5 additions & 0 deletions python/tvm/topi/hexagon/qnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,8 @@
""" Computes and schedules for Hexagon quantized ops """

from .avg_pool2d import qnn_avg_pool2d_compute, qnn_avg_pool2d_schedule

from .dequantize import (
dequantize_compute,
dequantize_schedule,
)
94 changes: 94 additions & 0 deletions python/tvm/topi/hexagon/qnn/dequantize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name

""" Hexagon qnn.dequantize slice op compute and schedule"""

from tvm import te
from tvm import tir
from ..utils import get_layout_transform_fn


def dequantize_compute(tensor_A, scale_A, zero_point_A):

return te.compute(
tensor_A.shape,
lambda *indices: (scale_A * (tensor_A[indices] - zero_point_A)).astype("float32"),
name="dequantize",
)


def dequantize_stir_schedule_nhwc_8h8w32c(
_in,
_out,
in_layout,
out_layout,
):
"""Schedule for nhwc int8/uint8 to f32 : nhwc layout"""
func = te.create_prim_func([_in, _out])
sch = tir.Schedule(func, debug_mask="all")
block_name = "dequantize"
n, h, w, c = sch.get_loops(sch.get_block(block_name))
ho, hi = sch.split(h, [None, 4])
wo, wi = sch.split(w, [None, 8])
wio, wii = sch.split(wi, [None, 4])
co, ci = sch.split(c, [None, 32])
sch.transform_layout(block_name, "A", in_layout)
sch.transform_layout(block_name, block_name, out_layout)
sch.reorder(n, ho, wo, co, hi, wio, wii, ci)
wii_ci = sch.fuse(wii, ci)
sch.vectorize(wii_ci)
return sch


def dequantize_stir_schedule_nc(
_in,
_out,
in_layout,
out_layout,
):
"""Schedule for nc int8/uint8 to f32 : nc layout"""
func = te.create_prim_func([_in, _out])
sch = tir.Schedule(func, debug_mask="all")
block_name = "dequantize"
_, c_orig = sch.get_loops(sch.get_block(block_name))
_, c_inner = sch.split(c_orig, [None, 512])
sch.transform_layout(block_name, "A", in_layout)
sch.transform_layout(block_name, block_name, out_layout)
sch.vectorize(c_inner)
return sch


def dequantize_schedule(_in, _output, in_layout_str, out_layout_str):
"""Schedule for int8/uint8 to f32 : top level function"""
f32_layout_transform_func = get_layout_transform_fn(out_layout_str)
in_layout_transform_func = get_layout_transform_fn(in_layout_str)
if out_layout_str == "nhwc-4h2w32c2w-2d":
return dequantize_stir_schedule_nhwc_8h8w32c(
_in,
_output,
in_layout_transform_func,
f32_layout_transform_func,
)
if out_layout_str == "nc-512c-2d":
return dequantize_stir_schedule_nc(
_in,
_output,
in_layout_transform_func,
f32_layout_transform_func,
)
raise RuntimeError(f"Unexpected layout '{layout}'")
7 changes: 7 additions & 0 deletions python/tvm/topi/hexagon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ def nc_2048_2d(n, c):
return [n, c // 2048, te.AXIS_SEPARATOR, c % 2048]


def nc_2048c_2d(n, c):
"""Return index map for nc_2048 2d layout"""
return [n, c // 2048, te.AXIS_SEPARATOR, c % 2048]


def nhwc_8h8w32c_2d(n, h, w, c):
"""Return index map for nhwc_8h8w32c 2d layout"""
return [n, h // 8, w // 8, c // 32, te.AXIS_SEPARATOR, h % 8, w % 8, c % 32]
Expand Down Expand Up @@ -156,6 +161,8 @@ def get_layout_transform_fn(layout):
return nhwc_2048c_2d
if layout == "nc-2048-2d":
return nc_2048_2d
if layout == "nc-2048c-2d":
return nc_2048c_2d
if layout == "nhwc-8h8w32c-2d":
return nhwc_8h8w32c_2d
if layout == "n11c-2048c-2d":
Expand Down
2 changes: 2 additions & 0 deletions tests/python/contrib/test_hexagon/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str):
return arr_np.reshape([n, c // 1024, 1024])
if new_layout in ["nc-512c-2d"]:
return arr_np.reshape([n, c // 512, 512])
if new_layout in ["nc-2048c-2d"]:
return arr_np.reshape([n, c // 2048, 2048])
raise RuntimeError(f"Unexpected new_layout '{new_layout}'")

if current_layout == "nhw":
Expand Down
121 changes: 121 additions & 0 deletions tests/python/contrib/test_hexagon/topi/test_dequantize_slice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name

""" Tests for Hexagon dequantize """
import numpy as np

import tvm
import tvm.testing
from tvm import te
from tvm.topi.hexagon import qnn
from ..infrastructure import allocate_hexagon_array, transform_numpy, quantize_np


class TestDequantizeSlice2d:
"""
For testing Dequantize Slice ops
"""

input_shape, orig_layout, input_layout, output_layout, axis_sep, dtype = tvm.testing.parameters(
((1, 16, 64, 128), "nhwc", "nhwc-8h8w32c-2d", "nhwc-4h2w32c2w-2d", [4], "int8"),
((1, 16, 64, 128), "nhwc", "nhwc-8h8w32c-2d", "nhwc-4h2w32c2w-2d", [4], "uint8"),
((1, 8, 8, 32), "nhwc", "nhwc-8h8w32c-2d", "nhwc-4h2w32c2w-2d", [4], "int8"),
((1, 8, 8, 32), "nhwc", "nhwc-8h8w32c-2d", "nhwc-4h2w32c2w-2d", [4], "uint8"),
((1, 2048), "nc", "nc-2048c-2d", "nc-512c-2d", [2], "int8"),
((1, 2048), "nc", "nc-2048c-2d", "nc-512c-2d", [2], "uint8"),
)

working_scope = tvm.testing.parameter("global.vtcm")

@tvm.testing.fixture
def input_np(self, input_shape):
arr_np = np.random.random(size=input_shape).astype("float32")
return arr_np

@tvm.testing.fixture
def transformed_input_np(self, input_np, orig_layout, input_layout, dtype):
quant_arr, scale, zero_point = quantize_np(input_np, dtype)
return [transform_numpy(quant_arr, orig_layout, input_layout), scale, zero_point]

@tvm.testing.fixture
def expected_output_np(self, input_np, dtype):
quant_np, scale, zero_point = quantize_np(input_np, dtype)
ref_np = (scale * (quant_np.astype("int32") - zero_point)).astype("float32")
return ref_np

@tvm.testing.fixture
def transformed_expected_output_np(self, expected_output_np, orig_layout, output_layout):
return transform_numpy(expected_output_np, orig_layout, output_layout)

@tvm.testing.requires_hexagon
def test_dequant_qnn(
self,
input_shape,
dtype,
input_layout,
output_layout,
transformed_input_np,
transformed_expected_output_np,
axis_sep,
hexagon_session,
working_scope,
):
"""
Top level testing function for dequantize
"""
target_hexagon = tvm.target.hexagon("v69")
target = tvm.target.Target(target_hexagon, host=target_hexagon)

dequant_input = te.placeholder(input_shape, name="A", dtype=dtype)

in_data_np, in_scale, in_zero_pt = transformed_input_np

dequant_output = qnn.dequantize_compute(dequant_input, in_scale, in_zero_pt)

tir_s = qnn.dequantize_schedule(dequant_input, dequant_output, input_layout, output_layout)

input_data = allocate_hexagon_array(
hexagon_session.device,
data=in_data_np,
axis_separators=axis_sep,
mem_scope=working_scope,
)
output_data = allocate_hexagon_array(
hexagon_session.device,
tensor_shape=transformed_expected_output_np.shape,
dtype=transformed_expected_output_np.dtype,
axis_separators=axis_sep,
mem_scope=working_scope,
)
with tvm.transform.PassContext(opt_level=3):
tir_irm = tvm.lower(tir_s.mod, [dequant_input, dequant_output], name="dequantize")
runtime_module = tvm.build(tir_irm, target=target, name="dequantize")
mod = hexagon_session.load_module(runtime_module)

mod(input_data, output_data)
output_np = output_data.numpy()
tvm.testing.assert_allclose(
output_np,
transformed_expected_output_np,
1e-3,
1e-3,
)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 6214da7

Please sign in to comment.