Skip to content

Commit

Permalink
[TOPI][Hexagon] Implement quantize op for hexagon (#12820)
Browse files Browse the repository at this point in the history
* [TOPI][Hexagon] Implement quantize op for hexagon

* Fix lint issue
  • Loading branch information
trahman-quic authored Sep 26, 2022
1 parent fd26813 commit e1f3f90
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 2 deletions.
2 changes: 2 additions & 0 deletions python/tvm/topi/hexagon/qnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@
dequantize_compute,
dequantize_schedule,
)

from .quantize import quantize_compute, tir_quantize_schedule
80 changes: 80 additions & 0 deletions python/tvm/topi/hexagon/qnn/quantize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Compute and schedule for hexagon quantize
Please note the following assumptions made by the implementation:
1) The input and output data will be a multiple of the crouton layout
2) The supported layout is NHWC
3) The input layout will be nhwc-4h2w32c2w-2d and
the output layout will be nhwc-8h8w32c-2d"""


from tvm import te
from tvm import tir
from ..utils import get_layout_transform_fn, saturate


def quantize_compute(tensor_A: te.Tensor, scale: float, zero_point: int, dtype: str):
    """Build the element-wise quantize compute for a 4-d tensor.

    Every element is multiplied by ``1 / scale``, cast to int32, offset by
    ``zero_point``, clamped with ``saturate`` for ``dtype`` and finally cast
    to ``dtype``.

    Parameters
    ----------
    tensor_A : te.Tensor
        Input tensor; it is indexed with four axes (n, h, w, c).
    scale : float
        Quantization scale; its reciprocal is applied to every element.
    zero_point : int
        Offset added after the int32 cast.
    dtype : str
        Data type used for saturation and for the final cast.

    Returns
    -------
    te.Tensor
        Compute named "quantize" with the same shape as ``tensor_A``.
    """
    inv_scale = 1 / scale

    # The argument names are kept as n, h, w, c because te.compute derives the
    # axis names from the callable's parameters.
    def _quantize_element(n, h, w, c):
        scaled = (tensor_A[n, h, w, c] * inv_scale).astype("int32")
        shifted = scaled + zero_point
        return saturate(shifted, dtype).astype(dtype)

    return te.compute(tensor_A.shape, _quantize_element, name="quantize")


def tir_quantize_schedule(
    out_M: te.Tensor,
    tensor_A: te.Tensor,
    input_layout: str,
    output_layout: str,
):
    """Create the TIR schedule for the "quantize" block.

    The schedule is written for the output layout nhwc-8h8w32c-2d.

    Parameters
    ----------
    out_M : te.Tensor
        Output tensor of the quantize compute.
    tensor_A : te.Tensor
        Input tensor of the quantize compute.
    input_layout : str
        Layout name passed to ``get_layout_transform_fn`` for the input buffer.
    output_layout : str
        Layout name passed to ``get_layout_transform_fn`` for the output buffer.

    Returns
    -------
    tir.Schedule
        Schedule with both buffer layouts transformed and the loops split
        and reordered.
    """
    prim_func = te.create_prim_func([tensor_A, out_M])
    sch = tir.Schedule(prim_func)
    quantize_block = sch.get_block("quantize")

    # Transform the input buffer layout first, then the output buffer layout.
    sch.transform_layout(
        quantize_block,
        buffer=tensor_A.name,
        index_map=get_layout_transform_fn(input_layout),
    )
    sch.transform_layout(
        quantize_block,
        buffer=out_M.name,
        index_map=get_layout_transform_fn(output_layout),
    )

    # The fixed chunk size is 2048 bytes. For uint8 (1 byte per element) the
    # fixed chunk layout is 8x8x32, so the loops are split and reordered to
    # iterate over one chunk at a time:
    #   - height is split by a factor of 8
    #   - width is split by a factor of 8, and its inner part again by 4
    #   - channel is split by a factor of 32
    batch, height, width, channel = sch.get_loops(quantize_block)

    height_outer, height_inner = sch.split(height, [None, 8])
    width_outer, width_inner = sch.split(width, [None, 8])
    channel_outer, channel_inner = sch.split(channel, [None, 32])
    width_inner_outer, width_inner_inner = sch.split(width_inner, [None, 4])

    sch.reorder(
        batch,
        height_outer,
        width_outer,
        channel_outer,
        height_inner,
        width_inner_outer,
        width_inner_inner,
        channel_inner,
    )

    return sch
5 changes: 5 additions & 0 deletions python/tvm/topi/hexagon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,3 +294,8 @@ def within_range(val, dtype):
fixed_point_value = int(round(flp * scale_f[0]))

return fixed_point_value, exp_scale_factor


def saturate(x: te.Tensor, dtype: str):
    """Clamp ``x`` between the minimum and maximum values of ``dtype``.

    Parameters
    ----------
    x : te.Tensor
        Value to clamp.
    dtype : str
        Data type whose ``te.min_value`` / ``te.max_value`` bound the result.

    Returns
    -------
    The expression ``max(min_value(dtype), min(x, max_value(dtype)))``.
    """
    lower_bound = te.min_value(dtype)
    upper_bound = te.max_value(dtype)
    clipped_above = te.min(x, upper_bound)
    return te.max(lower_bound, clipped_above)
4 changes: 2 additions & 2 deletions tests/python/contrib/test_hexagon/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,8 @@ def quantize_np(arr_np: numpy.ndarray, dtype: str):
qmax = 255
qmin = 0
elif dtype == "int8":
qmax = 128
qmin = -127
qmax = 127
qmin = -128
else:
raise RuntimeError(f"Unsupported quantized data type '{dtype}'")
fmin = numpy.amin(arr_np)
Expand Down
121 changes: 121 additions & 0 deletions tests/python/contrib/test_hexagon/topi/test_quantize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import numpy as np

import tvm
from tvm import te
import tvm.topi.hexagon.qnn as s1
from ..infrastructure import allocate_hexagon_array, transform_numpy, quantize_np


@tvm.testing.fixture
def expected_output_np(input_np, output_dtype):
    """Return the reference quantized array for ``input_np``.

    ``quantize_np`` also yields the scale and zero point; both are stored in
    the module-level globals ``scale`` and ``zero_point`` so the test body
    can pass them to the compute under test.
    """
    global scale, zero_point
    reference_quantized, scale, zero_point = quantize_np(input_np, output_dtype)
    return reference_quantized


@tvm.testing.fixture
def input_np(input_shape, input_dtype):
    """Return a random array of ``input_shape`` cast to ``input_dtype``."""
    random_values = np.random.random(input_shape)
    return random_values.astype(input_dtype)


@tvm.testing.fixture
def transformed_input_np(input_np, input_crouton_layout):
    """Return ``input_np`` transformed from "nhwc" to ``input_crouton_layout``."""
    source_layout = "nhwc"
    return transform_numpy(input_np, source_layout, input_crouton_layout)


@tvm.testing.fixture
def transformed_expected_output_np(expected_output_np, output_layout):
    """Return ``expected_output_np`` transformed from "nhwc" to ``output_layout``."""
    source_layout = "nhwc"
    return transform_numpy(expected_output_np, source_layout, output_layout)


class TestQuantize:
    """Compare the Hexagon quantize op against the numpy reference."""

    input_crouton_layout, output_layout, input_dtype = tvm.testing.parameters(
        ("nhwc-4h2w32c2w-2d", "nhwc-8h8w32c-2d", "float32"),
    )

    output_dtype = tvm.testing.parameter("uint8", "int8")

    input_shape = tvm.testing.parameter(
        (1, 8, 8, 32),
        (1, 16, 16, 32),
        (1, 16, 16, 128),
        (1, 64, 64, 64),
    )

    @tvm.testing.requires_hexagon
    def test_quantize(
        self,
        input_dtype,
        output_dtype,
        input_np,
        transformed_input_np,
        input_shape,
        expected_output_np,
        transformed_expected_output_np,
        input_crouton_layout,
        output_layout,
        hexagon_session,
    ):
        """Build the quantize op, run it in the session and check the output."""
        hexagon_target = tvm.target.hexagon("v69")
        A = te.placeholder(input_shape, name="A", dtype=input_dtype)

        # scale and zero_point are module-level globals assigned by the
        # expected_output_np fixture.
        M = s1.quantize_compute(A, scale, zero_point, output_dtype)

        schedule = s1.tir_quantize_schedule(M, A, input_crouton_layout, output_layout)
        scheduled_mod = schedule.mod

        with tvm.transform.PassContext(opt_level=3):
            built_func = tvm.build(
                scheduled_mod,
                [A, M],
                tvm.target.Target(hexagon_target, host=hexagon_target),
                name="quantize",
            )

        # Input and output arrays use the same axis separator.
        input_axis_separator = [4]
        output_axis_separator = [4]

        input_device_array = allocate_hexagon_array(
            hexagon_session.device,
            data=transformed_input_np,
            dtype=input_dtype,
            axis_separators=input_axis_separator,
            mem_scope="global.vtcm",
        )

        output_device_array = allocate_hexagon_array(
            hexagon_session.device,
            tensor_shape=transformed_expected_output_np.shape,
            dtype=output_dtype,
            axis_separators=output_axis_separator,
            mem_scope="global.vtcm",
        )

        loaded_module = hexagon_session.load_module(built_func)
        loaded_module(input_device_array, output_device_array)

        batch, height, width, channels = expected_output_np.shape

        # Convert the device array to numpy and reshape it to the fixed chunk
        # size layout (8x8x32 chunks).
        chunked_shape = [batch, height // 8, width // 8, channels // 32, 8, 8, 32]
        actual_output_np = output_device_array.numpy().reshape(chunked_shape)

        np.testing.assert_allclose(
            transformed_expected_output_np, actual_output_np, atol=1
        )


if __name__ == "__main__":
    # Hand control to tvm.testing.main() when this file is run as a script.
    tvm.testing.main()

0 comments on commit e1f3f90

Please sign in to comment.