diff --git a/src/layer/x86/gemm_int8.h b/src/layer/x86/gemm_int8.h
index d54c6db09ba..f1e76c87bf9 100644
--- a/src/layer/x86/gemm_int8.h
+++ b/src/layer/x86/gemm_int8.h
@@ -2865,7 +2865,7 @@ static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales,
         __m128 _tt1 = _mm256_extractf128_ps(_absmax_avx, 1);
         __m128 _absmax0 = _mm_unpacklo_ps(_tt0, _tt1);
         __m128 _absmax1 = _mm_unpackhi_ps(_tt0, _tt1);
-        _absmax_avx = _mm256_insertf128_ps(_mm256_castps128_ps256(_absmax0), _absmax1, 1);
+        _absmax_avx = combine4x2_ps(_absmax0, _absmax1);
         __m256 _scale = _mm256_div_ps(_v127_avx, _absmax_avx);
         __m256 _out_descale = _mm256_div_ps(_absmax_avx, _v127_B_scale_avx);
         _mm256_store_ps(ps, _scale);
@@ -8450,14 +8450,14 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
                 __m256 _ccf = _mm256_loadu_ps(pC + c_hstep * 15);
                 transpose8x8_ps(_cc0, _cc1, _cc2, _cc3, _cc4, _cc5, _cc6, _cc7);
                 transpose8x8_ps(_cc8, _cc9, _cca, _ccb, _ccc, _ccd, _cce, _ccf);
-                _c0 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc0), _cc8, 1);
-                _c1 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc1), _cc9, 1);
-                _c2 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc2), _cca, 1);
-                _c3 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc3), _ccb, 1);
-                _c4 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc4), _ccc, 1);
-                _c5 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc5), _ccd, 1);
-                _c6 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc6), _cce, 1);
-                _c7 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc7), _ccf, 1);
+                _c0 = combine8x2_ps(_cc0, _cc8);
+                _c1 = combine8x2_ps(_cc1, _cc9);
+                _c2 = combine8x2_ps(_cc2, _cca);
+                _c3 = combine8x2_ps(_cc3, _ccb);
+                _c4 = combine8x2_ps(_cc4, _ccc);
+                _c5 = combine8x2_ps(_cc5, _ccd);
+                _c6 = combine8x2_ps(_cc6, _cce);
+                _c7 = combine8x2_ps(_cc7, _ccf);
                 pC += 8;
             }
             if (beta == 1.f)
@@ -8768,19 +8768,10 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
                 _MM_TRANSPOSE4_PS(_cc8, _cc9, _cca, _ccb);
                 _MM_TRANSPOSE4_PS(_ccc, _ccd, _cce, _ccf);
 
-                __m256 _cc04 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc0), _cc4, 1);
-                __m256 _cc15 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc1), _cc5, 1);
-                __m256 _cc26 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc2), _cc6, 1);
-                __m256 _cc37 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc3), _cc7, 1);
-                __m256 _cc8c = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc8), _ccc, 1);
-                __m256 _cc9d = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc9), _ccd, 1);
-                __m256 _ccae = _mm256_insertf128_ps(_mm256_castps128_ps256(_cca), _cce, 1);
-                __m256 _ccbf = _mm256_insertf128_ps(_mm256_castps128_ps256(_ccb), _ccf, 1);
-
-                _c0 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc04), _cc8c, 1);
-                _c1 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc15), _cc9d, 1);
-                _c2 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc26), _ccae, 1);
-                _c3 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc37), _ccbf, 1);
+                _c0 = combine4x4_ps(_cc0, _cc4, _cc8, _ccc);
+                _c1 = combine4x4_ps(_cc1, _cc5, _cc9, _ccd);
+                _c2 = combine4x4_ps(_cc2, _cc6, _cca, _cce);
+                _c3 = combine4x4_ps(_cc3, _cc7, _ccb, _ccf);
 
                 pC += 4;
             }
@@ -9019,12 +9010,8 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
                 __m128 _cc5 = _mm_loadu_ps(pC + c_hstep * 8 + 4);
                 __m128 _cc6 = _mm_loadu_ps(pC + c_hstep * 12);
                 __m128 _cc7 = _mm_loadu_ps(pC + c_hstep * 12 + 4);
-                __m256 _cc02 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc0), _cc2, 1);
-                __m256 _cc46 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc4), _cc6, 1);
-                __m256 _cc13 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc1), _cc3, 1);
-                __m256 _cc57 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc5), _cc7, 1);
-                _c0 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc02), _cc46, 1);
-                _c1 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc13), _cc57, 1);
+                _c0 = combine4x4_ps(_cc0, _cc2, _cc4, _cc6);
+                _c1 = combine4x4_ps(_cc1, _cc3, _cc5, _cc7);
                 pC += 8;
             }
             else // if (c_elempack == 1)
@@ -9134,7 +9121,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
             {
                 __m256 _cc0 = _mm256_loadu_ps(pC);
                 __m256 _cc1 = _mm256_loadu_ps(pC + c_hstep * 8);
-                _c0 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc0), _cc1, 1);
+                _c0 = combine8x2_ps(_cc0, _cc1);
                 pC += 8;
             }
             else if (c_elempack == 4)
@@ -9143,9 +9130,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
                 __m128 _cc0 = _mm_loadu_ps(pC);
                 __m128 _cc1 = _mm_loadu_ps(pC + c_hstep * 4);
                 __m128 _cc2 = _mm_loadu_ps(pC + c_hstep * 8);
                 __m128 _cc3 = _mm_loadu_ps(pC + c_hstep * 12);
-                __m256 _cc01 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc0), _cc1, 1);
-                __m256 _cc23 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc2), _cc3, 1);
-                _c0 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc01), _cc23, 1);
+                _c0 = combine4x4_ps(_cc0, _cc1, _cc2, _cc3);
                 pC += 4;
             }
             else // if (c_elempack == 1)
@@ -10242,10 +10227,10 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
             _MM_TRANSPOSE4_PS(_cc0, _cc1, _cc2, _cc3);
             _MM_TRANSPOSE4_PS(_cc4, _cc5, _cc6, _cc7);
 
-            _c0 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc0), _cc4, 1);
-            _c1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc1), _cc5, 1);
-            _c2 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc2), _cc6, 1);
-            _c3 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc3), _cc7, 1);
+            _c0 = combine4x2_ps(_cc0, _cc4);
+            _c1 = combine4x2_ps(_cc1, _cc5);
+            _c2 = combine4x2_ps(_cc2, _cc6);
+            _c3 = combine4x2_ps(_cc3, _cc7);
 
             pC += 4;
         }
@@ -10546,7 +10531,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
 #else
             __m128i _f0l = _mm_load_si128((const __m128i*)pp);
             __m128i _f0h = _mm_load_si128((const __m128i*)pp1);
-            __m256 _f0 = _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_f0l), _f0h, 1));
+            __m256 _f0 = _mm256_cvtepi32_ps(combine4x2_epi32(_f0l, _f0h));
             pp += 4;
             pp1 += 4;
 #endif
@@ -10574,7 +10559,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
             {
                 __m128 _cc0 = _mm_loadu_ps(pC);
                 __m128 _cc1 = _mm_loadu_ps(pC + c_hstep * 4);
-                _c0 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc0), _cc1, 1);
+                _c0 = combine4x2_ps(_cc0, _cc1);
                 pC += 4;
             }
             else // if (c_elempack == 1)
@@ -12841,7 +12826,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
             {
                 __m512i _pA0 = _mm512_loadu_si512((const __m512i*)pA);
                 __m256i _pB = _mm256_loadu_si256((const __m256i*)pB);
-                __m512i _pB0 = _mm512_inserti32x8(_mm512_castsi256_si512(_pB), _pB, 1);
+                __m512i _pB0 = combine8x2_epi32(_pB, _pB);
                 __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC);
                 __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB);
                 __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2));
@@ -12888,7 +12873,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
                 // 1230 5674 1230 5674
                 // 4567 0123 4567 0123
                 // 5674 1230 5674 1230
-                __m512i _pB0 = _mm512_inserti32x8(_mm512_castsi256_si512(_pBB), _pBB, 1);
+                __m512i _pB0 = combine8x2_epi32(_pBB, _pBB);
                 __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB);
                 __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2));
                 __m512i _pB3 = _mm512_shuffle_epi32(_pB2, _MM_PERM_ADCB);
@@ -12915,7 +12900,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
 
                 __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(2, 3, 0, 1));
 
-                __m256i _pB0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_pB), _pB, 1);
+                __m256i _pB0 = combine4x2_epi32(_pB, _pB);
                 __m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1));
                 __m256i _pB2 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2));
                 __m256i _pB3 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB2, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1));
@@ -13247,7 +13232,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
             {
                 __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA);
                 __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB);
-                __m512i _pA00 = _mm512_inserti32x8(_mm512_castsi256_si512(_pA0), _pA0, 1);
+                __m512i _pA00 = combine8x2_epi32(_pA0, _pA0);
                 __m512i _pA11 = _mm512_shuffle_epi32(_pA00, _MM_PERM_BADC);
                 __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB);
                 __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2));
@@ -13266,7 +13251,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
             if (max_kk >= 4)
             {
                 __m256i _w_shift0 = _mm256_loadu_si256((const __m256i*)pA);
-                __m512i _w_shift00 = _mm512_inserti32x8(_mm512_castsi256_si512(_w_shift0), _w_shift0, 1);
+                __m512i _w_shift00 = combine8x2_epi32(_w_shift0, _w_shift0);
                 __m512i _w_shift11 = _mm512_shuffle_epi32(_w_shift00, _MM_PERM_BADC);
                 _sum0 = _mm512_sub_epi32(_sum0, _w_shift00);
                 _sum1 = _mm512_sub_epi32(_sum1, _w_shift00);
@@ -13289,7 +13274,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
 
                 // 0123 4567 0123 4567
                 // 2301 6745 2301 6745
-                __m512i _pA00 = _mm512_inserti32x8(_mm512_castsi256_si512(_pA0), _pA0, 1);
+                __m512i _pA00 = combine8x2_epi32(_pA0, _pA0);
                 __m512i _pA11 = _mm512_shuffle_epi32(_pA00, _MM_PERM_BADC);
 
                 // 0123 4567 89ab cdef
@@ -13320,7 +13305,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
                 _pA = _mm_cvtepi8_epi16(_pA);
                 __m256i _pB0 = _mm256_cvtepi8_epi16(_pB);
 
-                __m256i _pA00 = _mm256_inserti128_si256(_mm256_castsi128_si256(_pA), _pA, 1);
+                __m256i _pA00 = combine4x2_epi32(_pA, _pA);
                 __m256i _pA11 = _mm256_shuffle_epi32(_pA00, _MM_SHUFFLE(2, 3, 0, 1));
 
                 __m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1));
@@ -13541,7 +13526,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
             {
                 __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA);
                 __m128i _pB = _mm_loadu_si128((const __m128i*)pB);
-                __m256i _pB0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_pB), _pB, 1);
+                __m256i _pB0 = combine4x2_epi32(_pB, _pB);
                 __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2));
                 __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
 #if __AVXVNNIINT8__
@@ -13989,7 +13974,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
             {
                 __m128i _pA0 = _mm_loadu_si128((const __m128i*)pA);
                 __m256i _pB01 = _mm256_loadu_si256((const __m256i*)pB);
-                __m256i _pA00 = _mm256_inserti128_si256(_mm256_castsi128_si256(_pA0), _pA0, 1);
+                __m256i _pA00 = combine4x2_epi32(_pA0, _pA0);
                 __m256i _pA11 = _mm256_shuffle_epi32(_pA00, _MM_SHUFFLE(1, 0, 3, 2));
                 __m256i _pB23 = _mm256_shuffle_epi32(_pB01, _MM_SHUFFLE(0, 3, 2, 1));
 #if __AVXVNNIINT8__
@@ -14010,7 +13995,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
             if (max_kk >= 4)
             {
                 __m128i _w_shift0 = _mm_loadu_si128((const __m128i*)pA);
-                __m256i _w_shift00 = _mm256_inserti128_si256(_mm256_castsi128_si256(_w_shift0), _w_shift0, 1);
+                __m256i _w_shift00 = combine4x2_epi32(_w_shift0, _w_shift0);
                 __m256i _w_shift11 = _mm256_shuffle_epi32(_w_shift00, _MM_SHUFFLE(1, 0, 3, 2));
                 _sum0 = _mm256_sub_epi32(_sum0, _w_shift00);
                 _sum1 = _mm256_sub_epi32(_sum1, _w_shift11);
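Note on the helpers: this diff only touches call sites; combine4x2_ps, combine4x4_ps, combine8x2_ps, combine4x2_epi32 and combine8x2_epi32 are defined elsewhere (presumably in a shared header such as src/layer/x86/x86_usability.h; that location is an assumption, not shown here). Their behaviour follows directly from the one-to-one substitutions above: each helper concatenates its arguments into a wider vector, lowest lane first. A minimal sketch, assuming only that equivalence:

// Sketch only: signatures and bodies inferred from the substitutions in this diff.
#include <immintrin.h>

#if __AVX__
// two __m128 -> one __m256 (a goes to the low 128 bits, b to the high 128 bits)
static inline __m256 combine4x2_ps(__m128 a, __m128 b)
{
    return _mm256_insertf128_ps(_mm256_castps128_ps256(a), b, 1);
}
static inline __m256i combine4x2_epi32(__m128i a, __m128i b)
{
    // _mm256_insertf128_si256 needs only AVX; with AVX2 the integer-domain
    // _mm256_inserti128_si256 used by the removed code is equivalent
    return _mm256_insertf128_si256(_mm256_castsi128_si256(a), b, 1);
}
#endif // __AVX__

#if __AVX512F__
// two __m256 -> one __m512 (insertf32x8/inserti32x8 additionally require AVX512DQ)
static inline __m512 combine8x2_ps(__m256 a, __m256 b)
{
    return _mm512_insertf32x8(_mm512_castps256_ps512(a), b, 1);
}
static inline __m512i combine8x2_epi32(__m256i a, __m256i b)
{
    return _mm512_inserti32x8(_mm512_castsi256_si512(a), b, 1);
}
// four __m128 -> one __m512, lanes in argument order
static inline __m512 combine4x4_ps(__m128 a, __m128 b, __m128 c, __m128 d)
{
    return combine8x2_ps(combine4x2_ps(a, b), combine4x2_ps(c, d));
}
#endif // __AVX512F__

With that reading, a call such as combine4x4_ps(_cc0, _cc4, _cc8, _ccc) reproduces the removed _mm256_insertf128_ps + _mm512_insertf32x8 chain exactly, which is why the replacements are drop-in and do not change the generated lane layout.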