[microTVM] Replace arm_nnsupportfunctions.h with arm_acle.h (#13363)
* [microTVM] Replace arm_nnsupportfunctions.h with arm_acle.h

This attempts to replace the CMSIS-NN header with a more portable
alternative and avoid a dependency on CMSIS.

* Remove CMSIS __STATIC_FORCEINLINE macro
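For context, `__STATIC_FORCEINLINE` is just CMSIS's spelling of a compiler attribute pair; a minimal sketch of the swap, assuming the rough `cmsis_gcc.h` definition:

```c
#include <stdint.h>

/* Roughly what cmsis_gcc.h defines: */
#define __STATIC_FORCEINLINE __attribute__((always_inline)) static inline

/* The kernels now spell out the expansion directly, so no CMSIS header
   is required: */
__attribute__((always_inline)) static inline int32_t twice(int32_t x) {
  return 2 * x;
}
```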

* Replace more intrinsics with ACLE variants
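For example, CMSIS's `__SMLAD` becomes the ACLE intrinsic `__smlad` from `<arm_acle.h>`. A minimal sketch, assuming a DSP-capable M-profile target (`__ARM_FEATURE_DSP`); the wrapper name is illustrative only:

```c
#include <arm_acle.h>
#include <stdint.h>

/* Dual 16-bit multiply-accumulate: treats a and b as two int16 lanes
   each and returns acc + a.lo * b.lo + a.hi * b.hi. */
static int32_t dot_step(int32_t a, int32_t b, int32_t acc) {
  return __smlad(a, b, acc);
}
```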

* Use builtins for intrinsics missing in older GCC
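The guarded-fallback pattern looks roughly like this sketch; the convolve kernels below call `__builtin_arm_smlabb`/`__builtin_arm_smlatt` directly, and the wrapper name here is hypothetical:

```c
#include <stdint.h>

/* Older GCC ships an arm_acle.h that lacks some DSP intrinsics, so fall
   back to the equivalent builtin: acc + (a[15:0] * b[15:0]), with both
   halfwords sign-extended. */
static inline int32_t smlabb_compat(int32_t a, int32_t b, int32_t acc) {
  return __builtin_arm_smlabb(a, b, acc);
}
```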

* Re-use common_includes to propagate shared functions

The packing definitions aren't implemented as ACLE intrinsics, nor is there a simple way to convince a C compiler to generate them.
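As a plain-C behavioural model (not the generated code), the two packing operations compute:

```c
#include <stdint.h>

/* PKHBT: bottom halfword of x, top halfword of (y << sh). */
static uint32_t pkhbt_model(uint32_t x, uint32_t y, unsigned sh) {
  return (x & 0x0000FFFFu) | ((y << sh) & 0xFFFF0000u);
}

/* PKHTB: top halfword of x, bottom halfword of (y >> sh, arithmetic). */
static uint32_t pkhtb_model(uint32_t x, uint32_t y, unsigned sh) {
  uint32_t lo = sh ? (uint32_t)((int32_t)y >> sh) : y;
  return (x & 0xFFFF0000u) | (lo & 0x0000FFFFu);
}
```

Compilers don't reliably pattern-match this into single `pkhbt`/`pkhtb` instructions, hence the inline-asm macros added to common.py below.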

* Properly align memory access for

Introduce `memcpy` to tell the compiler that we're reinterpreting
`int16_t` data as `int32_t` without an alignment guarantee. What this
appears to actually do is encourage the compiler to use three word loads
rather than one double load plus a regular load.

The padded array is aligned as an `int16_t`; it isn't guaranteed to
behave like an `int32_t`-aligned array. One side effect of the
type punning from `int16_t*` to `int32_t*` is that we're effectively
telling the compiler the data is correctly aligned for `int32_t`, so it
can use instructions which load multiple `int32_t`s at the same time - this
does not work 😿
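A minimal before/after sketch of the change (function names hypothetical):

```c
#include <stdint.h>
#include <string.h>

/* Before: punning claims int32_t alignment that the int16_t buffer may
   not have, so the compiler is free to emit LDRD/LDM, which require it. */
int32_t sum_words_punned(const int16_t *pad16, int n_words) {
  const int32_t *p32 = (const int32_t *)pad16; /* alignment lie, UB */
  int32_t sum = 0;
  for (int i = 0; i < n_words; i++) sum += p32[i];
  return sum;
}

/* After: a fixed-size memcpy expresses the same reinterpretation without
   the alignment claim, and still lowers to plain word loads. */
int32_t sum_words_copied(const int16_t *pad16, int n_words) {
  int32_t sum = 0;
  for (int i = 0; i < n_words; i++) {
    int32_t w;
    memcpy(&w, &pad16[2 * i], sizeof(w));
    sum += w;
  }
  return sum;
}
```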

Co-authored-by: Ashutosh Parkhi <[email protected]>
Mousius and ashutosh-arm authored Jan 9, 2023
1 parent a435cbb commit 6bc72bb
Showing 5 changed files with 93 additions and 42 deletions.
6 changes: 3 additions & 3 deletions python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py
@@ -101,7 +101,7 @@ def sum_impl(N, uniq_id):
#ifdef __cplusplus
extern "C"
#endif // __cplusplus
-__STATIC_FORCEINLINE int32_t sum16_reset_{uniq_id}(
+__attribute__((always_inline)) static inline int32_t sum16_reset_{uniq_id}(
int16_t *res) {{
*res = (int16_t)0;
return 0;
@@ -110,7 +110,7 @@ def sum_impl(N, uniq_id):
#ifdef __cplusplus
extern "C"
#endif
-__STATIC_FORCEINLINE int32_t sum16_{N}_{uniq_id}(
+__attribute__((always_inline)) static inline int32_t sum16_{N}_{uniq_id}(
int16_t *arr,
int16_t *res16,
long arr_offset,
@@ -129,7 +129,7 @@ def sum_impl(N, uniq_id):
}}
for ( int i = 0; i < n / 2; ++ i ) {{
-res = __SMLAD(*p32, 0x00010001, res);
+res = __smlad(*p32, 0x00010001, res);
++ p32;
}}
34 changes: 33 additions & 1 deletion python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/common.py
@@ -24,10 +24,42 @@
#include <stdlib.h>
#include <string.h>
-#include <arm_nnsupportfunctions.h>
+#include <arm_acle.h>
#include <tvm/runtime/crt/error_codes.h>
+#ifndef ARM_CPU_INTRINSICS_EXIST
+#define ARM_CPU_INTRINSICS_EXIST
+__attribute__((always_inline)) uint32_t __ror(uint32_t op1, uint32_t op2)
+{
+  op2 %= 32U;
+  if (op2 == 0U)
+  {
+    return op1;
+  }
+  return (op1 >> op2) | (op1 << (32U - op2));
+}
+#define __pkhbt(ARG1,ARG2,ARG3) \
+__extension__ \
+({ \
+  uint32_t __RES, __ARG1 = (ARG1), __ARG2 = (ARG2); \
+  __asm("pkhbt %0, %1, %2, lsl %3" : "=r" (__RES) : "r" (__ARG1), "r" (__ARG2), "I" (ARG3) ); \
+  __RES; \
+})
+#define __pkhtb(ARG1,ARG2,ARG3) \
+__extension__ \
+({ \
+  uint32_t __RES, __ARG1 = (ARG1), __ARG2 = (ARG2); \
+  if (ARG3 == 0) \
+    __asm("pkhtb %0, %1, %2" : "=r" (__RES) : "r" (__ARG1), "r" (__ARG2) ); \
+  else \
+    __asm("pkhtb %0, %1, %2, asr %3" : "=r" (__RES) : "r" (__ARG1), "r" (__ARG2), "I" (ARG3) ); \
+  __RES; \
+})
+#endif
"""

MICRO_WORD_LENGTH_BITS = 32
66 changes: 43 additions & 23 deletions python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py
@@ -132,12 +132,30 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
cc_code = (
common.common_includes
+ f"""
+#ifndef ARM_CPU_MPROFILE_READ_AND_PAD_EXISTS
+#define ARM_CPU_MPROFILE_READ_AND_PAD_EXISTS
+__attribute__((always_inline)) static inline const int8_t *read_and_pad(const int8_t *source, int32_t *out1, int32_t *out2)
+{{
+  int32_t inA;
+  memcpy(&inA, source, 4);
+  source += 4;
+  int32_t inAbuf1 = __sxtb16(__ror((uint32_t)inA, 8));
+  int32_t inAbuf2 = __sxtb16(inA);
+  *out2 = (int32_t)(__pkhtb(inAbuf1, inAbuf2, 16));
+  *out1 = (int32_t)(__pkhbt(inAbuf2, inAbuf1, 16));
+  return source;
+}}
+#endif
"""
+ f"""
#ifdef __cplusplus
extern "C"
#endif
-__STATIC_FORCEINLINE int32_t gemm_{M}x{N}_body_rest_{uniq_id}(
+__attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_body_rest_{uniq_id}(
int K,
int8_t *aa, int8_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
@@ -180,7 +198,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#ifdef __cplusplus
extern "C"
#endif
-__STATIC_FORCEINLINE int32_t gemm_{M}x{K}x{N}_body_loop_{uniq_id}(
+__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_loop_{uniq_id}(
int8_t *aa, int8_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
for (int i = 0; i < {M}; i++) {{
@@ -201,7 +219,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#ifdef __cplusplus
extern "C"
#endif
-__STATIC_FORCEINLINE int32_t gemm_{M}x{K}x{N}_body_{uniq_id}(
+__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_{uniq_id}(
int8_t *aa, int8_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int16_t bb_pad[{bb_pad_size}];
@@ -226,7 +244,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
int32_t *bb_ptr = (int32_t *) &bb_pad[j*{K}];
int32_t sum = 0;
for (int l = 0; l < 2 * ({K} / 4); l++) {{
-sum = __SMLAD(*aa_ptr, *bb_ptr, sum);
+sum = __smlad(*aa_ptr, *bb_ptr, sum);
++ aa_ptr; ++ bb_ptr;
}}
// NOTE: this is the line where `*_body` differs from `*_update`. here
@@ -246,7 +264,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#ifdef __cplusplus
extern "C"
#endif
-__STATIC_FORCEINLINE int32_t gemm_{M}x{N}_update_rest_{uniq_id}(
+__attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_update_rest_{uniq_id}(
int K,
int8_t *aa, int8_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
@@ -289,7 +307,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#ifdef __cplusplus
extern "C"
#endif
-__STATIC_FORCEINLINE int32_t gemm_{M}x{K}x{N}_update_loop_{uniq_id}(
+__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_loop_{uniq_id}(
int8_t *aa, int8_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
for (int i = 0; i < {M}; i++) {{
@@ -307,7 +325,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#ifdef __cplusplus
extern "C"
#endif
-__STATIC_FORCEINLINE int32_t gemm_{M}x{K}x{N}_update_{uniq_id}(
+__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_{uniq_id}(
int8_t *aa, int8_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int16_t bb_pad[{bb_pad_size}];
@@ -332,7 +350,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
int32_t *bb_ptr = (int32_t *) &bb_pad[j*{K}];
int32_t sum = 0;
for (int l = 0; l < 2 * ({K} / 4); l++) {{
-sum = __SMLAD(*aa_ptr, *bb_ptr, sum);
+sum = __smlad(*aa_ptr, *bb_ptr, sum);
++ aa_ptr; ++ bb_ptr;
}}
cc[i*C_stride + j] += sum;
@@ -349,7 +367,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#ifdef __cplusplus
extern "C"
#endif
-__STATIC_FORCEINLINE int32_t gemm16_{M}x{N}_body_rest_{uniq_id}(
+__attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_body_rest_{uniq_id}(
int K,
int16_t *aa, int16_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
@@ -367,7 +385,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#ifdef __cplusplus
extern "C"
#endif
-__STATIC_FORCEINLINE int32_t gemm16_{M}x{K}x{N}_body_loop_{uniq_id}(
+__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_loop_{uniq_id}(
int16_t *aa, int16_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
for (int i = 0; i < {M}; i++) {{
@@ -388,7 +406,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#ifdef __cplusplus
extern "C"
#endif
-__STATIC_FORCEINLINE int32_t gemm16_{M}x{K}x{N}_body_{uniq_id}(
+__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_{uniq_id}(
int16_t *aa, int16_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t retcode = 0;
@@ -405,13 +423,14 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
-int32_t *aa_ptr = (int32_t *) &aa[i*A_stride];
-int32_t *bb_ptr = (int32_t *) &bb[j*B_stride];
+int32_t aa_vector[{K} / 2];
+int32_t bb_vector[{K} / 2];
+memcpy(&aa_vector, &aa[i * A_stride], sizeof(aa_vector));
+memcpy(&bb_vector, &bb[j * B_stride], sizeof(bb_vector));
int32_t sum = 0;
for (int l = 0; l < {K} / 2; l++) {{
-sum = __SMLAD(*aa_ptr, *bb_ptr, sum);
-++ aa_ptr; ++ bb_ptr;
+sum = __smlad(aa_vector[l], bb_vector[l], sum);
}}
// NOTE: this is the line where `*_body` differs from `*_update`. here
// we're *setting* the result, instead of accumulating, because we know
@@ -430,7 +449,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#ifdef __cplusplus
extern "C"
#endif
-__STATIC_FORCEINLINE int32_t gemm16_{M}x{N}_update_rest_{uniq_id}(
+__attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_update_rest_{uniq_id}(
int K,
int16_t *aa, int16_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
@@ -448,7 +467,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#ifdef __cplusplus
extern "C"
#endif
-__STATIC_FORCEINLINE int32_t gemm16_{M}x{K}x{N}_update_loop_{uniq_id}(
+__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_loop_{uniq_id}(
int16_t *aa, int16_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
for (int i = 0; i < {M}; i++) {{
@@ -466,7 +485,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#ifdef __cplusplus
extern "C"
#endif
-__STATIC_FORCEINLINE int32_t gemm16_{M}x{K}x{N}_update_{uniq_id}(
+__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_{uniq_id}(
int16_t *aa, int16_t *bb, int32_t *cc,
int A_stride, int B_stride, int C_stride) {{
int32_t retcode = 0;
@@ -478,13 +497,14 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
-int32_t *aa_ptr = (int32_t *) &aa[i*A_stride];
-int32_t *bb_ptr = (int32_t *) &bb[j*B_stride];
+int32_t aa_vector[{K} / 2];
+int32_t bb_vector[{K} / 2];
+memcpy(&aa_vector, &aa[i * A_stride], sizeof(aa_vector));
+memcpy(&bb_vector, &bb[j * B_stride], sizeof(bb_vector));
int32_t sum = 0;
for (int l = 0; l < {K} / 2; l++) {{
-sum = __SMLAD(*aa_ptr, *bb_ptr, sum);
-++ aa_ptr; ++ bb_ptr;
+sum = __smlad(aa_vector[l], bb_vector[l], sum);
}}
cc[i*C_stride + j] += sum;
}}
@@ -500,7 +520,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
#ifdef __cplusplus
extern "C"
#endif
-__STATIC_FORCEINLINE int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int C_stride) {{
+__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int C_stride) {{
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
cc[i*C_stride + j] = 0;
10 changes: 5 additions & 5 deletions python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py
@@ -94,7 +94,7 @@ def max_impl(uniq_id):
#ifdef __cplusplus
extern "C"
#endif
-__STATIC_FORCEINLINE int32_t max8_reset_{uniq_id}(
+__attribute__((always_inline)) static inline int32_t max8_reset_{uniq_id}(
int8_t *res,
int N) {{
memset(res, (int8_t)-128, N * sizeof(*res));
@@ -104,7 +104,7 @@ def max_impl(uniq_id):
#ifdef __cplusplus
extern "C"
#endif
-__STATIC_FORCEINLINE int32_t max8_loop_{uniq_id}(
+__attribute__((always_inline)) static inline int32_t max8_loop_{uniq_id}(
int8_t *arg,
int8_t *res,
int N) {{
@@ -117,7 +117,7 @@ def max_impl(uniq_id):
#ifdef __cplusplus
extern "C"
#endif
-__STATIC_FORCEINLINE int32_t max8_{uniq_id}(
+__attribute__((always_inline)) static inline int32_t max8_{uniq_id}(
int8_t *arg,
int8_t *res,
int N) {{
@@ -146,8 +146,8 @@ def max_impl(uniq_id):
for ( int i = 0; i < N / 4; ++ i ) {{
int32_t arg32 = *parg32 ++;
int32_t res32 = *pres32;
-__SSUB8(arg32, res32);
-res32 = __SEL(arg32, res32);
+__ssub8(arg32, res32);
+res32 = __sel(arg32, res32);
*pres32 ++ = res32;
}}
19 changes: 9 additions & 10 deletions python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/multi_channel_convolve.py
@@ -23,7 +23,7 @@
import textwrap

from tvm import te, tir
-from .common import num_simd_lanes_per_word
+from .common import num_simd_lanes_per_word, common_includes


def _get_func_name(in_dtype, tensor_w, channels, kernel_h, kernel_w, suffix):
@@ -107,10 +107,8 @@ def multi_channel_convolve_impl(in_dtype, *args) -> str:
def _quad_int8_channel_convolve_impl(_tensor_h, tensor_w, channels, kernel_h, kernel_w, suffix):
return textwrap.dedent(
(
f"""
#include <stdint.h>
#include <arm_nnsupportfunctions.h>
common_includes
+ f"""
// __SXTB16(__ROR(X, Y)) is combined into one assembly instruction
#define TVMGEN_QUAD_INT8_CHANNEL_REARRANGE_SUM_DSP( \
@@ -120,13 +118,13 @@ def _quad_int8_channel_convolve_impl(_tensor_h, tensor_w, channels, kernel_h, kernel_w, suffix):
\
uint32_t kernel_c3210 = *arranged_kernel++; \
\
-uint32_t tensor_c20 = __SXTB16(tensor_c3210); \
-uint32_t kernel_c20 = __SXTB16(kernel_c3210); \
+uint32_t tensor_c20 = __sxtb16(tensor_c3210); \
+uint32_t kernel_c20 = __sxtb16(kernel_c3210); \
sum_c0 = __builtin_arm_smlabb(tensor_c20, kernel_c20, sum_c0); \
sum_c2 = __builtin_arm_smlatt(tensor_c20, kernel_c20, sum_c2); \
\
-uint32_t tensor_c31 = __SXTB16(__ROR(tensor_c3210, 8)); \
-uint32_t kernel_c31 = __SXTB16(__ROR(kernel_c3210, 8)); \
+uint32_t tensor_c31 = __sxtb16(__ror(tensor_c3210, 8)); \
+uint32_t kernel_c31 = __sxtb16(__ror(kernel_c3210, 8)); \
sum_c1 = __builtin_arm_smlabb(tensor_c31, kernel_c31, sum_c1); \
sum_c3 = __builtin_arm_smlatt(tensor_c31, kernel_c31, sum_c3); \
}}
@@ -172,7 +170,8 @@ def _quad_int8_channel_convolve_impl(_tensor_h, tensor_w, channels, kernel_h, kernel_w, suffix):
def _dual_int16_channel_convolve_impl(_tensor_h, tensor_w, channels, kernel_h, kernel_w, suffix):
return textwrap.dedent(
(
f"""
common_includes
+ f"""
#include <stdint.h>
/* We do four channels at once to get this speed boost. */
