[CMSIS-NN] Support int16 handling for pooling functions (#13498)

- Pattern matching and RelayToTIR introduce int16 support
- Added int16 variants to the pooling tests
neildhickey authored Nov 29, 2022
1 parent c0ba8a1 commit f6f7fea
Showing 4 changed files with 39 additions and 17 deletions.
12 changes: 8 additions & 4 deletions python/tvm/relay/op/contrib/cmsisnn.py
@@ -287,8 +287,10 @@ def check_qnn_avg_pool2d(pattern):
     return (
         pooling.attrs.layout == "NHWC"
         and int(input_op.checked_type.shape[0]) == 1
-        and input_op.checked_type.dtype == "int8"
-        and output.checked_type.dtype == "int8"
+        and (
+            (input_op.checked_type.dtype == "int8" and output.checked_type.dtype == "int8")
+            or (input_op.checked_type.dtype == "int16" and output.checked_type.dtype == "int16")
+        )
     )

def qnn_max_pool2d_pattern():
@@ -310,8 +312,10 @@ def check_qnn_max_pool2d(pattern):
     return (
         pooling.attrs.layout == "NHWC"
         and int(input_op.checked_type.shape[0]) == 1
-        and input_op.checked_type.dtype == "int8"
-        and output.checked_type.dtype == "int8"
+        and (
+            (input_op.checked_type.dtype == "int8" and output.checked_type.dtype == "int8")
+            or (input_op.checked_type.dtype == "int16" and output.checked_type.dtype == "int16")
+        )
     )

def binary_op_pattern(op):
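Both checkers now accept a match only when the input and output dtypes agree and are one of the two widths CMSIS-NN pooling supports. A minimal standalone sketch of that shared predicate, with a hypothetical helper name that is not part of this commit:

def pool_dtypes_supported(input_dtype: str, output_dtype: str) -> bool:
    # CMSIS-NN ships s8 and s16 pooling kernels; mixed-width graphs
    # (e.g. int8 in, int16 out) are rejected and stay with the default codegen.
    return input_dtype == output_dtype and input_dtype in ("int8", "int16")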
29 changes: 23 additions & 6 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -428,12 +428,19 @@ class RelayToTIRVisitor : public MixedModeMutator {
       pool = final_call;
     }
 
+    int32_t dtype_bits = final_call->type_as<TensorTypeNode>()->dtype.bits();
+
     // prepare cmsis_nn_pool_params
     int32_t stride_h, stride_w, padding_h, padding_w, pool_size_h, pool_size_w;
     int32_t clip_min, clip_max;
     std::string cmsisnn_api;
     if (pool_name == "cmsis-nn.qnn_avg_pool2d") {
-      cmsisnn_api = "arm_avgpool_s8";
+      if (dtype_bits == 8) {
+        cmsisnn_api = "arm_avgpool_s8";
+      } else {
+        cmsisnn_api = "arm_avgpool_s16";
+      }
+
       const AvgPool2DAttrs* attrs = pool->attrs.as<AvgPool2DAttrs>();
       stride_h = qnn::get_const_int(attrs->strides[0]);
       stride_w = qnn::get_const_int(attrs->strides[1]);
@@ -442,7 +449,12 @@ class RelayToTIRVisitor : public MixedModeMutator {
       pool_size_h = qnn::get_const_int(attrs->pool_size[0]);
       pool_size_w = qnn::get_const_int(attrs->pool_size[1]);
     } else {
-      cmsisnn_api = "arm_max_pool_s8";
+      if (dtype_bits == 8) {
+        cmsisnn_api = "arm_max_pool_s8";
+      } else {
+        cmsisnn_api = "arm_max_pool_s16";
+      }
+
       const MaxPool2DAttrs* attrs = pool->attrs.as<MaxPool2DAttrs>();
       stride_h = qnn::get_const_int(attrs->strides[0]);
       stride_w = qnn::get_const_int(attrs->strides[1]);
@@ -456,8 +468,13 @@ class RelayToTIRVisitor : public MixedModeMutator {
       clip_min = clip_attrs->a_min;
       clip_max = clip_attrs->a_max;
     } else {
-      clip_min = -128;
-      clip_max = 127;
+      if (dtype_bits == 8) {
+        clip_min = std::numeric_limits<int8_t>::min();
+        clip_max = std::numeric_limits<int8_t>::max();
+      } else {
+        clip_min = std::numeric_limits<int16_t>::min();
+        clip_max = std::numeric_limits<int16_t>::max();
+      }
     }
 
     tvm::Array<PrimExpr> scalar_args = {ToArg(stride_h), ToArg(stride_w), ToArg(padding_h),
@@ -472,8 +489,8 @@ class RelayToTIRVisitor : public MixedModeMutator {
     Array<PrimExpr> cmsisnn_output_shape{1, output_shape[1], output_shape[2], output_shape[3]};
 
     BufferCreator buffer_creator;
-    tir::Var input = buffer_creator.CreateBufferVar("input", DataType::Handle(8));
-    tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8));
+    tir::Var input = buffer_creator.CreateBufferVar("input", DataType::Handle(dtype_bits));
+    tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(dtype_bits));
     tvm::Array<PrimExpr> call_ext_args = {tir::StringImm(cmsisnn_api), input, output};
 
     int context_buffer_size = 0;
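RelayToTIR now derives both the CMSIS-NN entry point and the default saturation bounds from the element width instead of hard-coding the s8 variants. A hedged Python restatement of that dispatch, with illustrative names:

def select_pool_kernel(pool_name: str, dtype_bits: int):
    # 8-bit tensors map to the arm_*_s8 kernels, 16-bit tensors to arm_*_s16.
    suffix = "s8" if dtype_bits == 8 else "s16"
    if pool_name == "cmsis-nn.qnn_avg_pool2d":
        cmsisnn_api = "arm_avgpool_" + suffix
    else:
        cmsisnn_api = "arm_max_pool_" + suffix
    # With no explicit clip op, saturate to the full range of the element
    # type: [-128, 127] for int8, [-32768, 32767] for int16.
    clip_min = -(1 << (dtype_bits - 1))
    clip_max = (1 << (dtype_bits - 1)) - 1
    return cmsisnn_api, clip_min, clip_max

For example, select_pool_kernel("cmsis-nn.qnn_avg_pool2d", 16) returns ("arm_avgpool_s16", -32768, 32767), the same bounds std::numeric_limits<int16_t> yields above.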
3 changes: 2 additions & 1 deletion src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
@@ -118,7 +118,8 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
     } else if (cmsis_func_name == "arm_fully_connected_s8" ||
                cmsis_func_name == "arm_fully_connected_s16") {
       EmitFullyConnected(op);
-    } else if (cmsis_func_name == "arm_avgpool_s8" || cmsis_func_name == "arm_max_pool_s8") {
+    } else if (cmsis_func_name == "arm_avgpool_s8" || cmsis_func_name == "arm_avgpool_s16" ||
+               cmsis_func_name == "arm_max_pool_s8" || cmsis_func_name == "arm_max_pool_s16") {
       EmitPool2D(op);
     }
     return;
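The C codegen routes external calls by function name, so both new s16 entry points must reach the same pooling emitter. A compact Python analogue of the updated routing (the table and function are illustrative only):

POOLING_FUNCS = frozenset({
    "arm_avgpool_s8", "arm_avgpool_s16",
    "arm_max_pool_s8", "arm_max_pool_s16",
})

def is_pooling_call(cmsis_func_name: str) -> bool:
    # All four pooling entry points share EmitPool2D in CodeGenCMSISNN.
    return cmsis_func_name in POOLING_FUNCS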
12 changes: 6 additions & 6 deletions tests/python/contrib/test_cmsisnn/test_pooling.py
@@ -81,6 +81,7 @@ def make_model(
 
 
 @tvm.testing.requires_cmsisnn
+@pytest.mark.parametrize("dtype", ["int16", "int8"])
 @pytest.mark.parametrize("in_shape", [(1, 28, 28, 12), (1, 64, 100, 4)])
 @pytest.mark.parametrize(
     "pool_size, strides, padding", [((3, 3), (2, 2), "SAME"), ((2, 2), (1, 1), "VALID")]
@@ -91,7 +92,8 @@ def make_model(
 @pytest.mark.parametrize(
     "compiler_cpu, cpu_flags", [("cortex-m55", "+nomve"), ("cortex-m55", ""), ("cortex-m7", "")]
 )
-def test_op_int8(
+def test_ops(
+    dtype,
     in_shape,
     pool_size,
     strides,
@@ -103,18 +105,17 @@ def test_op_int8(
     compiler_cpu,
     cpu_flags,
 ):
-    """Tests QNN pooling op for int8 inputs"""
+    """Tests QNN pooling ops for int8 and int16 inputs"""
     interface_api = "c"
     use_unpacked_api = True
 
-    dtype = "int8"
-
     model = make_model(
         pool_op=pool_type,
         shape=in_shape,
         pool_size=pool_size,
         strides=strides,
         padding=padding,
+        dtype=dtype,
         scale=scale,
         zero_point=zero_point,
         relu_type=relu_type,
@@ -130,7 +131,7 @@ def test_op_int8(
     in_min, in_max = get_range_for_dtype_str(dtype)
     np.random.seed(0)
     inputs = {
-        "input": np.random.randint(in_min, high=in_max, size=in_shape, dtype="int8"),
+        "input": np.random.randint(in_min, high=in_max, size=in_shape, dtype=dtype),
     }
     output_list = generate_ref_data(orig_mod["main"], inputs)
     compile_and_run(
@@ -211,7 +212,6 @@ def test_int8_pool_with_float32_input(
 def test_invalid_datatype(op):
     """Checks CMSIS-NN partitioning for non int8 dtype"""
     model = make_model(pool_op=op, dtype="int64")
-
     orig_mod = make_module(model)
     cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
     assert_no_external_function(cmsisnn_mod)
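With dtype parametrized, one test body now covers both widths end to end. A hedged sketch of driving the helper directly for the 16-bit case, assuming make_model's signature as it appears in the diff and using illustrative argument values:

from tvm import relay

# Hypothetical direct call to the test helper with the new dtype knob;
# pool_op and the quantization parameters mirror one pytest combination.
model = make_model(
    pool_op=relay.nn.max_pool2d,
    shape=(1, 28, 28, 12),
    pool_size=(3, 3),
    strides=(2, 2),
    padding="SAME",
    dtype="int16",
    scale=0.1,
    zero_point=0,
    relu_type="RELU",
)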
