From 15ebb6193197a1ba9cb789a4828b6601a092a062 Mon Sep 17 00:00:00 2001
From: nihui
Date: Fri, 29 Nov 2024 03:25:41 +0000
Subject: [PATCH] comp avxvnni

---
 CMakeLists.txt                |  4 +-
 src/layer/x86/gemm_int8.h     | 78 ++++++++++++++--------------
 src/layer/x86/x86_usability.h | 96 +++++++++++++++++++++----------------
 3 files changed, 99 insertions(+), 79 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index bf0e9f20fb8..097c191a066 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -504,7 +504,7 @@ else()
         check_cxx_compiler_flag("/arch:AVX512" NCNN_COMPILER_SUPPORT_X86_AVX512)
 
         set(CMAKE_REQUIRED_FLAGS "/arch:AVX2")
-        check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI)
+        check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI)
 
         set(CMAKE_REQUIRED_FLAGS "/arch:AVX2")
         check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpbssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8)
@@ -541,7 +541,7 @@ else()
         check_cxx_compiler_flag("/arch:AVX512 -mfma -mf16c -mavx512cd -mavx512bw -mavx512dq -mavx512vl" NCNN_COMPILER_SUPPORT_X86_AVX512)
 
         set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxvnni")
-        check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI)
+        check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI)
 
         set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxvnni -mavxvnniint8")
         check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpbssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8)
diff --git a/src/layer/x86/gemm_int8.h b/src/layer/x86/gemm_int8.h
index 089709c9225..c59cd7e789d 100644
--- a/src/layer/x86/gemm_int8.h
+++ b/src/layer/x86/gemm_int8.h
@@ -18490,14 +18490,14 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
             __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
             __m256i _pB2 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2));
             __m256i _pB3 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 1, 0, 3));
-            _sum0 = _mm256_dpbusd_epi32(_sum0, _pB0, _pA0);
-            _sum1 = _mm256_dpbusd_epi32(_sum1, _pB1, _pA0);
-            _sum2 = _mm256_dpbusd_epi32(_sum2, _pB0, _pA1);
-            _sum3 = _mm256_dpbusd_epi32(_sum3, _pB1, _pA1);
-            _sum4 = _mm256_dpbusd_epi32(_sum4, _pB2, _pA0);
-            _sum5 = _mm256_dpbusd_epi32(_sum5, _pB3, _pA0);
-            _sum6 = _mm256_dpbusd_epi32(_sum6, _pB2, _pA1);
-            _sum7 = _mm256_dpbusd_epi32(_sum7, _pB3, _pA1);
+            _sum0 = _mm256_comp_dpbusd_epi32(_sum0, _pB0, _pA0);
+            _sum1 = _mm256_comp_dpbusd_epi32(_sum1, _pB1, _pA0);
+            _sum2 = _mm256_comp_dpbusd_epi32(_sum2, _pB0, _pA1);
+            _sum3 = _mm256_comp_dpbusd_epi32(_sum3, _pB1, _pA1);
+            _sum4 = _mm256_comp_dpbusd_epi32(_sum4, _pB2, _pA0);
+            _sum5 = _mm256_comp_dpbusd_epi32(_sum5, _pB3, _pA0);
+            _sum6 = _mm256_comp_dpbusd_epi32(_sum6, _pB2, _pA1);
+            _sum7 = _mm256_comp_dpbusd_epi32(_sum7, _pB3, _pA1);
             pA += 32;
             pB += 32;
         }
@@ -18646,10 +18646,10 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
             __m256i _pB0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_pB), _pB, 1);
             __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2));
             __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
-            _sum0 = _mm256_dpbusd_epi32(_sum0, _pB0, _pA0);
-            _sum1 = _mm256_dpbusd_epi32(_sum1, _pB1, _pA0);
-            _sum2 = _mm256_dpbusd_epi32(_sum2, _pB0, _pA1);
-            _sum3 = _mm256_dpbusd_epi32(_sum3, _pB1, _pA1);
+            _sum0 = _mm256_comp_dpbusd_epi32(_sum0, _pB0, _pA0);
+            _sum1 = _mm256_comp_dpbusd_epi32(_sum1, _pB1, _pA0);
+            _sum2 = _mm256_comp_dpbusd_epi32(_sum2, _pB0, _pA1);
+            _sum3 = _mm256_comp_dpbusd_epi32(_sum3, _pB1, _pA1);
             pA += 32;
             pB += 16;
         }
@@ -18752,8 +18752,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
             __m256i _pA = _mm256_loadu_si256((const __m256i*)pA);
             __m256i _pB0 = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)pB));
             __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 1, 0, 1));
-            _sum0 = _mm256_dpbusd_epi32(_sum0, _pB0, _pA);
-            _sum1 = _mm256_dpbusd_epi32(_sum1, _pB1, _pA);
+            _sum0 = _mm256_comp_dpbusd_epi32(_sum0, _pB0, _pA);
+            _sum1 = _mm256_comp_dpbusd_epi32(_sum1, _pB1, _pA);
             pA += 32;
             pB += 8;
         }
@@ -18836,7 +18836,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
         {
             __m256i _pA = _mm256_loadu_si256((const __m256i*)pA);
             __m256i _pB = _mm256_castps_si256(_mm256_broadcast_ss((const float*)pB));
-            _sum0 = _mm256_dpbusd_epi32(_sum0, _pB, _pA);
+            _sum0 = _mm256_comp_dpbusd_epi32(_sum0, _pB, _pA);
             pA += 32;
             pB += 4;
         }
@@ -19057,14 +19057,14 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
             __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2));
             __m128i _pB2 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
             __m128i _pB3 = _mm_shuffle_epi32(_pB1, _MM_SHUFFLE(0, 3, 2, 1));
-            _sum0 = _mm_dpbusd_epi32(_sum0, _pB0, _pA0);
-            _sum1 = _mm_dpbusd_epi32(_sum1, _pB1, _pA0);
-            _sum2 = _mm_dpbusd_epi32(_sum2, _pB0, _pA1);
-            _sum3 = _mm_dpbusd_epi32(_sum3, _pB1, _pA1);
-            _sum4 = _mm_dpbusd_epi32(_sum4, _pB2, _pA0);
-            _sum5 = _mm_dpbusd_epi32(_sum5, _pB3, _pA0);
-            _sum6 = _mm_dpbusd_epi32(_sum6, _pB2, _pA1);
-            _sum7 = _mm_dpbusd_epi32(_sum7, _pB3, _pA1);
+            _sum0 = _mm_comp_dpbusd_epi32(_sum0, _pB0, _pA0);
+            _sum1 = _mm_comp_dpbusd_epi32(_sum1, _pB1, _pA0);
+            _sum2 = _mm_comp_dpbusd_epi32(_sum2, _pB0, _pA1);
+            _sum3 = _mm_comp_dpbusd_epi32(_sum3, _pB1, _pA1);
+            _sum4 = _mm_comp_dpbusd_epi32(_sum4, _pB2, _pA0);
+            _sum5 = _mm_comp_dpbusd_epi32(_sum5, _pB3, _pA0);
+            _sum6 = _mm_comp_dpbusd_epi32(_sum6, _pB2, _pA1);
+            _sum7 = _mm_comp_dpbusd_epi32(_sum7, _pB3, _pA1);
             pA += 16;
             pB += 32;
         }
@@ -19255,10 +19255,10 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
             __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB);
             __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2));
             __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
-            _sum0 = _mm_dpbusd_epi32(_sum0, _pB0, _pA0);
-            _sum1 = _mm_dpbusd_epi32(_sum1, _pB1, _pA0);
-            _sum2 = _mm_dpbusd_epi32(_sum2, _pB0, _pA1);
-            _sum3 = _mm_dpbusd_epi32(_sum3, _pB1, _pA1);
+            _sum0 = _mm_comp_dpbusd_epi32(_sum0, _pB0, _pA0);
+            _sum1 = _mm_comp_dpbusd_epi32(_sum1, _pB1, _pA0);
+            _sum2 = _mm_comp_dpbusd_epi32(_sum2, _pB0, _pA1);
+            _sum3 = _mm_comp_dpbusd_epi32(_sum3, _pB1, _pA1);
             pA += 16;
             pB += 16;
         }
@@ -19399,8 +19399,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
             __m128i _pA = _mm_loadu_si128((const __m128i*)pA);
             __m128i _pB0 = _mm_castpd_si128(_mm_load1_pd((const double*)pB));
             __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 3, 0, 1));
-            _sum0 = _mm_dpbusd_epi32(_sum0, _pB0, _pA);
-            _sum1 = _mm_dpbusd_epi32(_sum1, _pB1, _pA);
+            _sum0 = _mm_comp_dpbusd_epi32(_sum0, _pB0, _pA);
+            _sum1 = _mm_comp_dpbusd_epi32(_sum1, _pB1, _pA);
             pA += 16;
             pB += 8;
         }
@@ -19511,7 +19511,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
         {
             __m128i _pA = _mm_loadu_si128((const __m128i*)pA);
             __m128i _pB = _mm_castps_si128(_mm_load1_ps((const float*)pB));
-            _sum0 = _mm_dpbusd_epi32(_sum0, _pB, _pA);
+            _sum0 = _mm_comp_dpbusd_epi32(_sum0, _pB, _pA);
             pA += 16;
             pB += 4;
         }
@@ -19711,10 +19711,10 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
             __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(2, 3, 0, 1));
             __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB);
             __m128i _pB1 = _mm_loadu_si128((const __m128i*)(pB + 16));
-            _sum0 = _mm_dpbusd_epi32(_sum0, _pB0, _pA0);
-            _sum1 = _mm_dpbusd_epi32(_sum1, _pB1, _pA0);
-            _sum2 = _mm_dpbusd_epi32(_sum2, _pB0, _pA1);
-            _sum3 = _mm_dpbusd_epi32(_sum3, _pB1, _pA1);
+            _sum0 = _mm_comp_dpbusd_epi32(_sum0, _pB0, _pA0);
+            _sum1 = _mm_comp_dpbusd_epi32(_sum1, _pB1, _pA0);
+            _sum2 = _mm_comp_dpbusd_epi32(_sum2, _pB0, _pA1);
+            _sum3 = _mm_comp_dpbusd_epi32(_sum3, _pB1, _pA1);
             pA += 8;
             pB += 32;
         }
@@ -19837,8 +19837,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
             __m128i _pA = _mm_castpd_si128(_mm_load1_pd((const double*)pA));
             __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB);
             __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
-            _sum0 = _mm_dpbusd_epi32(_sum0, _pB0, _pA);
-            _sum1 = _mm_dpbusd_epi32(_sum1, _pB1, _pA);
+            _sum0 = _mm_comp_dpbusd_epi32(_sum0, _pB0, _pA);
+            _sum1 = _mm_comp_dpbusd_epi32(_sum1, _pB1, _pA);
             pA += 8;
             pB += 16;
         }
@@ -20177,8 +20177,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
             __m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA));
             __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB);
             __m128i _pB1 = _mm_loadu_si128((const __m128i*)(pB + 16));
-            _sum0 = _mm_dpbusd_epi32(_sum0, _pB0, _pA);
-            _sum1 = _mm_dpbusd_epi32(_sum1, _pB1, _pA);
+            _sum0 = _mm_comp_dpbusd_epi32(_sum0, _pB0, _pA);
+            _sum1 = _mm_comp_dpbusd_epi32(_sum1, _pB1, _pA);
             pA += 4;
             pB += 32;
         }
@@ -20265,7 +20265,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
         {
             __m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA));
             __m128i _pB = _mm_loadu_si128((const __m128i*)pB);
-            _sum0 = _mm_dpbusd_epi32(_sum0, _pB, _pA);
+            _sum0 = _mm_comp_dpbusd_epi32(_sum0, _pB, _pA);
             pA += 4;
             pB += 16;
         }
diff --git a/src/layer/x86/x86_usability.h b/src/layer/x86/x86_usability.h
index e3d1d11fbc5..53bc40c3c32 100644
--- a/src/layer/x86/x86_usability.h
+++ b/src/layer/x86/x86_usability.h
@@ -267,83 +267,83 @@ static NCNN_FORCEINLINE __m128i float2bfloat_sse(const __m128& v0, const __m128&
     return _v;
 }
 
-#ifndef __FMA__
-static NCNN_FORCEINLINE __m128 _mm_comp_fmadd_ps(const __m128& _a, const __m128& _b, const __m128& _c)
-{
-    return _mm_add_ps(_mm_mul_ps(_a, _b), _c);
-}
-static NCNN_FORCEINLINE __m128 _mm_comp_fnmadd_ps(const __m128& _a, const __m128& _b, const __m128& _c)
-{
-    return _mm_sub_ps(_c, _mm_mul_ps(_a, _b));
-}
-static NCNN_FORCEINLINE __m128 _mm_comp_fmsub_ps(const __m128& _a, const __m128& _b, const __m128& _c)
-{
-    return _mm_sub_ps(_mm_mul_ps(_a, _b), _c);
-}
-static NCNN_FORCEINLINE __m128 _mm_comp_fnmsub_ps(const __m128& _a, const __m128& _b, const __m128& _c)
-{
-    return _mm_sub_ps(_c, _mm_mul_ps(_mm_mul_ps(_a, _b), _mm_set1_ps(-1)));
-}
-#else
 static NCNN_FORCEINLINE __m128 _mm_comp_fmadd_ps(const __m128& _a, const __m128& _b, const __m128& _c)
 {
+#if __FMA__
     return _mm_fmadd_ps(_a, _b, _c);
+#else
+    return _mm_add_ps(_mm_mul_ps(_a, _b), _c);
+#endif
 }
+
 static NCNN_FORCEINLINE __m128 _mm_comp_fnmadd_ps(const __m128& _a, const __m128& _b, const __m128& _c)
 {
     // return -a * b + c
+#if __FMA__
     return _mm_fnmadd_ps(_a, _b, _c);
+#else
+    return _mm_sub_ps(_c, _mm_mul_ps(_a, _b));
+#endif
 }
+
 static NCNN_FORCEINLINE __m128 _mm_comp_fmsub_ps(const __m128& _a, const __m128& _b, const __m128& _c)
 {
+#if __FMA__
     return _mm_fmsub_ps(_a, _b, _c);
+#else
+    return _mm_sub_ps(_mm_mul_ps(_a, _b), _c);
+#endif
 }
+
 static NCNN_FORCEINLINE __m128 _mm_comp_fnmsub_ps(const __m128& _a, const __m128& _b, const __m128& _c)
 {
+#if __FMA__
     return _mm_fnmsub_ps(_a, _b, _c);
+#else
+    return _mm_sub_ps(_c, _mm_mul_ps(_mm_mul_ps(_a, _b), _mm_set1_ps(-1)));
+#endif
 }
-#endif // !__FMA__
 
 #if __AVX__
-#ifndef __FMA__
-static NCNN_FORCEINLINE __m256 _mm256_comp_fmadd_ps(const __m256& _a, const __m256& _b, const __m256& _c)
-{
-    return _mm256_add_ps(_mm256_mul_ps(_a, _b), _c);
-}
-static NCNN_FORCEINLINE __m256 _mm256_comp_fnmadd_ps(const __m256& _a, const __m256& _b, const __m256& _c)
-{
-    return _mm256_sub_ps(_c, _mm256_mul_ps(_a, _b));
-}
-static NCNN_FORCEINLINE __m256 _mm256_comp_fmsub_ps(const __m256& _a, const __m256& _b, const __m256& _c)
-{
-    return _mm256_sub_ps(_mm256_mul_ps(_a, _b), _c);
-}
-static NCNN_FORCEINLINE __m256 _mm256_comp_fnmsub_ps(const __m256& _a, const __m256& _b, const __m256& _c)
-{
-    return _mm256_sub_ps(_c, _mm256_mul_ps(_mm256_mul_ps(_a, _b), _mm256_set1_ps(-1)));
-}
-#else
 static NCNN_FORCEINLINE __m256 _mm256_comp_fmadd_ps(const __m256& _a, const __m256& _b, const __m256& _c)
 {
     // return a * b + c
+#if __FMA__
     return _mm256_fmadd_ps(_a, _b, _c);
+#else
+    return _mm256_add_ps(_mm256_mul_ps(_a, _b), _c);
+#endif
 }
+
 static NCNN_FORCEINLINE __m256 _mm256_comp_fnmadd_ps(const __m256& _a, const __m256& _b, const __m256& _c)
 {
     // return -a * b + c
+#if __FMA__
     return _mm256_fnmadd_ps(_a, _b, _c);
+#else
+    return _mm256_sub_ps(_c, _mm256_mul_ps(_a, _b));
+#endif
 }
+
 static NCNN_FORCEINLINE __m256 _mm256_comp_fmsub_ps(const __m256& _a, const __m256& _b, const __m256& _c)
 {
     // return a * b - c
+#if __FMA__
     return _mm256_fmsub_ps(_a, _b, _c);
+#else
+    return _mm256_sub_ps(_mm256_mul_ps(_a, _b), _c);
+#endif
 }
+
 static NCNN_FORCEINLINE __m256 _mm256_comp_fnmsub_ps(const __m256& _a, const __m256& _b, const __m256& _c)
 {
     // return -(a * b) - c
+#if __FMA__
     return _mm256_fnmsub_ps(_a, _b, _c);
-}
+#else
+    return _mm256_sub_ps(_c, _mm256_mul_ps(_mm256_mul_ps(_a, _b), _mm256_set1_ps(-1)));
 #endif
+}
 
 static NCNN_FORCEINLINE __m256 _mm256_fmadd_1_ps(const __m256& a, const __m256& b, float c)
 {
@@ -841,6 +841,26 @@ static NCNN_FORCEINLINE __m256i float2bfloat_avx(const __m256& v0, const __m256&
 }
 
 #if __AVX2__
+#if __AVX512VNNI__ || __AVXVNNI__
+static NCNN_FORCEINLINE __m128i _mm_comp_dpbusd_epi32(__m128i src, __m128i a, __m128i b)
+{
+#if __AVX512VNNI__
+    return _mm_dpbusd_epi32(src, a, b);
+#else
+    return _mm_dpbusd_avx_epi32(src, a, b);
+#endif
+}
+
+static NCNN_FORCEINLINE __m256i _mm256_comp_dpbusd_epi32(__m256i src, __m256i a, __m256i b)
+{
+#if __AVX512VNNI__
+    return _mm256_dpbusd_epi32(src, a, b);
+#else
+    return _mm256_dpbusd_avx_epi32(src, a, b);
+#endif
+}
+#endif // __AVX512VNNI__ || __AVXVNNI__
+
 static NCNN_FORCEINLINE void transpose8x2_epi32(__m256i& _r0, __m256i& _r1)
 {
     __m256i _tmp0 = _mm256_unpacklo_epi32(_r0, _r1);
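
Background on the dispatch introduced in x86_usability.h: `_mm256_dpbusd_epi32` is the AVX512-VNNI intrinsic (EVEX-encoded; its 256-bit form also needs AVX512VL), while `_mm256_dpbusd_avx_epi32` is the VEX-encoded AVX-VNNI spelling that GCC and Clang expose under plain `-mavxvnni`. The new `_mm_comp_dpbusd_epi32` / `_mm256_comp_dpbusd_epi32` helpers let the int8 GEMM kernels call one name and have the preprocessor pick whichever encoding the current compile flags provide. The CMake change follows the same logic: probing `_mm256_dpwssd_avx_epi32` rather than `_mm256_dpwssd_epi32` appears intended to make the AVX-VNNI feature test use the spelling that compilers accept without AVX512-VNNI. A minimal standalone sketch of the same dispatch, compilable outside ncnn (the wrapper name `comp_dpbusd_epi32` and the `main` driver below are illustrative, not part of the patch):

    // build: g++ -mavx2 -mavxvnni dpbusd_sketch.cpp
    //    or: g++ -mavx512vnni -mavx512vl dpbusd_sketch.cpp
    #include <immintrin.h>
    #include <stdio.h>

    static inline __m256i comp_dpbusd_epi32(__m256i src, __m256i a, __m256i b)
    {
    #if __AVX512VNNI__
        return _mm256_dpbusd_epi32(src, a, b); // EVEX-encoded AVX512-VNNI form
    #else
        return _mm256_dpbusd_avx_epi32(src, a, b); // VEX-encoded AVX-VNNI form
    #endif
    }

    int main()
    {
        // dpbusd: per 32-bit lane, src + sum of 4 products of
        // (unsigned byte of a) * (signed byte of b)
        __m256i src = _mm256_setzero_si256();
        __m256i a = _mm256_set1_epi8(2);  // read as unsigned bytes
        __m256i b = _mm256_set1_epi8(-3); // read as signed bytes
        int out[8];
        _mm256_storeu_si256((__m256i*)out, comp_dpbusd_epi32(src, a, b));
        printf("%d\n", out[0]); // each lane: 0 + 4 * (2 * -3) = -24
        return 0;
    }

Either build produces the same -24 per lane; only the instruction encoding differs, which is why the kernels in gemm_int8.h can switch to the `_comp_` wrappers without any behavioral change.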