diff --git a/src/layer/x86/dequantize_x86.cpp b/src/layer/x86/dequantize_x86.cpp index 5e4f8b55310..272b820da90 100644 --- a/src/layer/x86/dequantize_x86.cpp +++ b/src/layer/x86/dequantize_x86.cpp @@ -40,45 +40,27 @@ static void dequantize(const int* intptr, float* ptr, const Mat& scale_data, con // NCNN_LOGE("dequantize %d %d %d %d", scale_data_size, bias_data_size, elemcount, elempack); - const float* scale_ptr = scale_data; - - float scale = 0.f; + float scale = scale_data[0]; #if __SSE2__ - __m128 _scale = _mm_setzero_ps(); + __m128 _scale = _mm_set1_ps(scale); #if __AVX__ - __m256 _scale_avx = _mm256_setzero_ps(); + __m256 _scale_avx = _mm256_set1_ps(scale); #if __AVX512F__ - __m512 _scale_avx512 = _mm512_setzero_ps(); + __m512 _scale_avx512 = _mm512_set1_ps(scale); #endif // __AVX512F__ #endif // __AVX__ -#endif // __SSE2__ - - if (scale_data_size == 1 || elempack == 1) + if (scale_data_size > 1) { - scale = scale_ptr[0]; -#if __SSE2__ - _scale = _mm_set1_ps(scale); -#if __AVX__ - _scale_avx = _mm256_set1_ps(scale); -#if __AVX512F__ - _scale_avx512 = _mm512_set1_ps(scale); -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ - } - else - { -#if __SSE2__ #if __AVX__ #if __AVX512F__ if (elempack == 16) { - _scale_avx512 = _mm512_loadu_ps(scale_ptr); + _scale_avx512 = _mm512_loadu_ps((const float*)scale_data); } #endif // __AVX512F__ if (elempack == 8) { - _scale_avx = _mm256_loadu_ps(scale_ptr); + _scale_avx = _mm256_loadu_ps((const float*)scale_data); #if __AVX512F__ _scale_avx512 = combine8x2_ps(_scale_avx, _scale_avx); #endif // __AVX512F__ @@ -86,7 +68,7 @@ static void dequantize(const int* intptr, float* ptr, const Mat& scale_data, con #endif // __AVX__ if (elempack == 4) { - _scale = _mm_loadu_ps(scale_ptr); + _scale = _mm_loadu_ps((const float*)scale_data); #if __AVX__ _scale_avx = combine4x2_ps(_scale, _scale); #if __AVX512F__ @@ -94,8 +76,8 @@ static void dequantize(const int* intptr, float* ptr, const Mat& scale_data, con #endif // __AVX512F__ #endif // __AVX__ } -#endif // __SSE2__ } +#endif // __SSE2__ if (bias_data_size == 0) { @@ -139,45 +121,27 @@ static void dequantize(const int* intptr, float* ptr, const Mat& scale_data, con } else { - const float* bias_ptr = bias_data; - - float bias = 0.f; -#if __SSE2__ - __m128 _bias = _mm_setzero_ps(); -#if __AVX__ - __m256 _bias_avx = _mm256_setzero_ps(); -#if __AVX512F__ - __m512 _bias_avx512 = _mm512_setzero_ps(); -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ - - if (bias_data_size == 1 || elempack == 1) - { - bias = bias_ptr[0]; + float bias = bias_data[0]; #if __SSE2__ - _bias = _mm_set1_ps(bias); + __m128 _bias = _mm_set1_ps(bias); #if __AVX__ - _bias_avx = _mm256_set1_ps(bias); + __m256 _bias_avx = _mm256_set1_ps(bias); #if __AVX512F__ - _bias_avx512 = _mm512_set1_ps(bias); + __m512 _bias_avx512 = _mm512_set1_ps(bias); #endif // __AVX512F__ #endif // __AVX__ -#endif // __SSE2__ - } - else + if (bias_data_size > 1) { -#if __SSE2__ #if __AVX__ #if __AVX512F__ if (elempack == 16) { - _bias_avx512 = _mm512_loadu_ps(bias_ptr); + _bias_avx512 = _mm512_loadu_ps((const float*)bias_data); } #endif // __AVX512F__ if (elempack == 8) { - _bias_avx = _mm256_loadu_ps(bias_ptr); + _bias_avx = _mm256_loadu_ps((const float*)bias_data); #if __AVX512F__ _bias_avx512 = combine8x2_ps(_bias_avx, _bias_avx); #endif // __AVX512F__ @@ -185,7 +149,7 @@ static void dequantize(const int* intptr, float* ptr, const Mat& scale_data, con #endif // __AVX__ if (elempack == 4) { - _bias = _mm_loadu_ps(bias_ptr); + _bias = _mm_loadu_ps((const float*)bias_data); #if __AVX__ _bias_avx = combine4x2_ps(_bias, _bias); #if __AVX512F__ @@ -193,8 +157,8 @@ static void dequantize(const int* intptr, float* ptr, const Mat& scale_data, con #endif // __AVX512F__ #endif // __AVX__ } -#endif // __SSE2__ } +#endif // __SSE2__ int i = 0; #if __SSE2__