diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py
index 8964937469c4..779fe35c3718 100644
--- a/python/tvm/relay/op/contrib/cmsisnn.py
+++ b/python/tvm/relay/op/contrib/cmsisnn.py
@@ -231,27 +231,40 @@ def check_qnn_fully_connected(pattern):
         requantize = pattern
         requantize_input = requantize.args[0]
         bias_add = None
-        bias_dtype = "int32"
         if str(requantize_input.op.name) == "nn.bias_add":
             bias_add = requantize_input
             fc = bias_add.args[0]
-            bias_dtype = bias_add.args[1].checked_type.dtype
         else:
             fc = requantize_input
         fc_input = fc.args[0]
         fc_weight = fc.args[1]
 
+        are_dtypes_valid = False
+        fc_input_dtype = fc_input.checked_type.dtype
+        if bias_add:
+            bias_dtype = bias_add.args[1].checked_type.dtype
+        else:
+            bias_dtype = "int32" if fc_input_dtype == "int8" else "int64"
+
+        valid_dtypes = None
+        if fc_input_dtype == "int8":
+            valid_dtypes = ("int8", "int8", "int32", "int32", "int8")
+        elif fc_input_dtype == "int16":
+            valid_dtypes = ("int16", "int8", "int64", "int64", "int16")
+
+        if (
+            fc_input_dtype,
+            fc_weight.checked_type.dtype,
+            bias_dtype,
+            fc.attrs.out_dtype,
+            pattern.checked_type.dtype,
+        ) == valid_dtypes:
+            are_dtypes_valid = True
+
         # kernel zero_point should be 0
         kernel_zp = fc.args[3].data.numpy().item(0)
 
-        return (
-            fc.attrs.out_dtype == "int32"
-            and fc_input.checked_type.dtype == "int8"
-            and fc_weight.checked_type.dtype == "int8"
-            and pattern.checked_type.dtype == "int8"
-            and bias_dtype == "int32"
-            and kernel_zp == 0
-        )
+        return are_dtypes_valid and kernel_zp == 0
 
     def qnn_avg_pool2d_pattern():
         """Matches average pooling with optional Relu"""
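The pattern check above collapses five chained boolean comparisons into a single tuple comparison keyed off the input dtype. A minimal standalone sketch of the same validation idea (the dict, function, and asserts are illustrative, not part of the patch):

```python
# Sketch of the dtype validation in check_qnn_fully_connected; names here
# are mine, not TVM's.
VALID_FC_DTYPES = {
    # input: (input, weight, bias, accumulator, output)
    "int8": ("int8", "int8", "int32", "int32", "int8"),
    "int16": ("int16", "int8", "int64", "int64", "int16"),
}

def dtypes_are_valid(input_dtype, weight_dtype, bias_dtype, out_dtype, output_dtype):
    """True iff the dtype tuple matches a CMSIS-NN s8 or s16 kernel."""
    expected = VALID_FC_DTYPES.get(input_dtype)  # None for e.g. "uint8"
    return (input_dtype, weight_dtype, bias_dtype, out_dtype, output_dtype) == expected

assert dtypes_are_valid("int8", "int8", "int32", "int32", "int8")
assert dtypes_are_valid("int16", "int8", "int64", "int64", "int16")
assert not dtypes_are_valid("int16", "int8", "int32", "int32", "int16")
assert not dtypes_are_valid("uint8", "int8", "int32", "int32", "int8")
```

Comparing against `None` for an unsupported input dtype falls out of the tuple equality for free, which is why the Relay check needs no separate unsupported-dtype branch.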
diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
index 1ea020e884de..c9e41589fb4b 100644
--- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
+++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -192,6 +192,11 @@ class RelayToTIRVisitor : public MixedModeMutator {
     std::string kernel_layout = conv2d_attrs->kernel_layout.c_str();
     int32_t clip_min = std::numeric_limits<int8_t>::min();
     int32_t clip_max = std::numeric_limits<int8_t>::max();
+
+    if (dtype_bits == 16) {
+      clip_min = std::numeric_limits<int16_t>::min();
+      clip_max = std::numeric_limits<int16_t>::max();
+    }
     if (clip_call) {
       const ClipAttrs* clip_attrs = clip_call->attrs.as<ClipAttrs>();
       clip_min = clip_attrs->a_min;
       clip_max = clip_attrs->a_max;
@@ -309,6 +314,14 @@ class RelayToTIRVisitor : public MixedModeMutator {
       fc_call = requantize_input;
     }
 
+    // Extract the size of the input parameter from the call arguments. Other params are based on
+    // the input size.
+    int32_t dtype_bits = fc_call->args[0]->type_as<TensorTypeNode>()->dtype.bits();
+    int32_t input_bits = dtype_bits;
+    int32_t filter_bits = 8;
+    int32_t bias_bits = dtype_bits * 4U;
+    int32_t output_bits = dtype_bits;
+
     // TIR variables are created in the order they appear in the Relay partitioned function
     // %1 = qnn.dense(%input, %weight_const_0, input_zero_point_scalar, kernel_zero_point_scalar,
     //      %input_scale_scalar, %kernel_scale_scalar)
@@ -317,12 +330,12 @@ class RelayToTIRVisitor : public MixedModeMutator {
     // %2 = nn.bias_add(%1, %bias_const_1, axis=1)
     // %3 = qnn.requantize(%2, %req_inp_scale_scalar, %req_inp_zp_scalar,
     //      %output_scale_scalar, %output_zero_point_scalar)
     // clip(%3, a_min=%min_scalar, a_max=%max_scalar)
     BufferCreator buffer_creator;
-    tir::Var input = buffer_creator.CreateBufferVar("input", DataType::Handle(8));
-    tir::Var filter = buffer_creator.CreateBufferVar("filter", DataType::Handle(8));
+    tir::Var input = buffer_creator.CreateBufferVar("input", DataType::Handle(input_bits));
+    tir::Var filter = buffer_creator.CreateBufferVar("filter", DataType::Handle(filter_bits));
     if (bias_add_call) {
-      buffer_creator.CreateBufferVar("bias", DataType::Handle(32));
+      buffer_creator.CreateBufferVar("bias", DataType::Handle(bias_bits));
     }
-    tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8));
+    tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(output_bits));
 
     // Individual arguments to the structs arguments of the CMSIS-NN API are filled into call_extern
     // https://github.com/ARM-software/CMSIS_5/blob/def6f800f95661eb3451d317f7d0dde504f6020d/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L50
@@ -341,8 +354,13 @@ class RelayToTIRVisitor : public MixedModeMutator {
       clip_min = clip_attrs->a_min;
       clip_max = clip_attrs->a_max;
     } else {
-      clip_min = -128;
-      clip_max = 127;
+      if (dtype_bits == 8) {
+        clip_min = std::numeric_limits<int8_t>::min();
+        clip_max = std::numeric_limits<int8_t>::max();
+      } else {
+        clip_min = std::numeric_limits<int16_t>::min();
+        clip_max = std::numeric_limits<int16_t>::max();
+      }
     }
 
     double quantized_multiplier =
@@ -366,7 +384,10 @@ class RelayToTIRVisitor : public MixedModeMutator {
 
     Array<PrimExpr> cmsisnn_output_shape{batch_size, 1, 1, out_channels};
 
-    tvm::Array<PrimExpr> call_ext_args = {tir::StringImm("arm_fully_connected_s8"), input, filter};
+    std::string cmsisnn_api =
+        dtype_bits == 16 ? "arm_fully_connected_s16" : "arm_fully_connected_s8";
+
+    tvm::Array<PrimExpr> call_ext_args = {tir::StringImm(cmsisnn_api), input, filter};
     if (bias_add_call) {
       call_ext_args.push_back(buffer_creator.GetBufferVar("bias"));
     }
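In the TIR lowering, every buffer width follows from the input dtype: activations and outputs keep the input width, weights stay 8-bit (CMSIS-NN s16 kernels use s8 weights), and the bias/accumulator is four times the input width. When no clip op was fused, the clip range defaults to the full range of the output type. A hedged Python sketch of that derivation (helper names are mine, not TVM's):

```python
# Sketch of the width/clip derivation in relay_to_tir.cc above.
def fc_buffer_bits(dtype_bits):
    assert dtype_bits in (8, 16)
    return {
        "input": dtype_bits,
        "filter": 8,             # s8 weights for both s8 and s16 kernels
        "bias": dtype_bits * 4,  # 32-bit bias for s8, 64-bit for s16
        "output": dtype_bits,
    }

def default_clip_range(dtype_bits):
    # Mirrors std::numeric_limits<intN_t>::min()/max(), used when no clip
    # operator was fused into the partitioned function.
    hi = (1 << (dtype_bits - 1)) - 1
    return -hi - 1, hi

assert fc_buffer_bits(16) == {"input": 16, "filter": 8, "bias": 64, "output": 16}
assert default_clip_range(8) == (-128, 127)
assert default_clip_range(16) == (-32768, 32767)
```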
"arm_fully_connected_s16" : "arm_fully_connected_s8"; + + tvm::Array call_ext_args = {tir::StringImm(cmsisnn_api), input, filter}; if (bias_add_call) { call_ext_args.push_back(buffer_creator.GetBufferVar("bias")); } diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc index ae9f195ca509..b5c5058ddbc0 100644 --- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc +++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc @@ -115,7 +115,8 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost { cmsis_func_name == "arm_depthwise_conv_wrapper_s8" || cmsis_func_name == "arm_depthwise_conv_wrapper_s16") { EmitConv2D(op); - } else if (cmsis_func_name == "arm_fully_connected_s8") { + } else if (cmsis_func_name == "arm_fully_connected_s8" || + cmsis_func_name == "arm_fully_connected_s16") { EmitFullyConnected(op); } else if (cmsis_func_name == "arm_avgpool_s8" || cmsis_func_name == "arm_max_pool_s8") { EmitPool2D(op); diff --git a/tests/python/contrib/test_cmsisnn/test_fully_connected.py b/tests/python/contrib/test_cmsisnn/test_fully_connected.py index 6fa1cc687f81..3b220eb42c9b 100644 --- a/tests/python/contrib/test_cmsisnn/test_fully_connected.py +++ b/tests/python/contrib/test_cmsisnn/test_fully_connected.py @@ -32,6 +32,7 @@ assert_partitioned_function, assert_no_external_function, create_test_runner, + get_kernel_bias_dtype, ) @@ -46,6 +47,7 @@ def make_model( output_scale, dtype, kernel_dtype, + bias_dtype, out_channels, enable_bias, relu_type="NONE", @@ -70,11 +72,11 @@ def make_model( input_scale=relay.const(input_scale, "float32"), kernel_scale=relay.const(kernel_scale, "float32"), units=out_channels, - out_dtype="int32", + out_dtype=bias_dtype, ) - bias = tvm.nd.array(rng.integers(0, high=10, size=(out_channels,), dtype="int32")) - bias_const = relay.const(bias, "int32") + bias = tvm.nd.array(rng.integers(0, high=10, size=(out_channels,), dtype=bias_dtype)) + bias_const = relay.const(bias, bias_dtype) last_op = relay.nn.bias_add(dense, bias_const) if enable_bias else dense requant_input_sc = input_scale * kernel_scale last_op = relay.qnn.op.requantize( @@ -91,6 +93,7 @@ def make_model( @tvm.testing.requires_cmsisnn +@pytest.mark.parametrize("dtype", ["int8", "int16"]) @pytest.mark.parametrize("in_shape", [(2, 28), (1, 64)]) @pytest.mark.parametrize("out_channels", [12, 128]) @pytest.mark.parametrize("enable_bias", [False, True]) @@ -101,7 +104,8 @@ def make_model( @pytest.mark.parametrize( "compiler_cpu, cpu_flags", [("cortex-m55", "+nomve"), ("cortex-m55", ""), ("cortex-m7", "")] ) -def test_op_int8( +def test_ops( + dtype, in_shape, enable_bias, input_zero_point, @@ -115,7 +119,7 @@ def test_op_int8( interface_api = "c" use_unpacked_api = True - dtype = "int8" + kernel_dtype, bias_dtype = get_kernel_bias_dtype(dtype) kernel_zero_point = 0 kernel_shape = [out_channels, in_shape[1]] conv2d_kernel_shape = (1, 1, kernel_shape[0], kernel_shape[1]) @@ -140,7 +144,8 @@ def test_op_int8( output_zero_point, output_scale, dtype, - dtype, + kernel_dtype, + bias_dtype, out_channels, enable_bias, ) @@ -170,13 +175,15 @@ def test_op_int8( def parameterize_for_invalid_model(test): """Generates parameters for non int8 inputs to fully connected layer""" - in_dtype = ["uint8", "int8"] + in_dtype = ["uint8", "int8", "int16"] kernel_dtype = ["uint8", "int8"] kernel_zero_point = [-33, 10, 0] all_combinations = itertools.product(in_dtype, kernel_dtype, kernel_zero_point) all_combinations = filter( lambda parameters: not ( - parameters[0] == 
"int8" and parameters[1] == "int8" and parameters[2] == 0 + (parameters[0] == "int8" or parameters[0] == "int16") + and parameters[1] == "int8" + and parameters[2] == 0 ), all_combinations, ) @@ -199,6 +206,7 @@ def test_invalid_parameters( input_scale = 1 input_zero_point = 24 kernel_scale = [0.11, 0.0237] + _, bias_dtype = get_kernel_bias_dtype(in_dtype) kernel_shape = [out_channels, in_shape[1]] conv2d_kernel_shape = [1, 1, kernel_shape[0], kernel_shape[1]] @@ -223,6 +231,7 @@ def test_invalid_parameters( output_scale=output_scale, dtype=in_dtype, kernel_dtype=kernel_dtype, + bias_dtype=bias_dtype, out_channels=out_channels, enable_bias=True, )