Skip to content

Commit

Permalink
divps
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Nov 14, 2024
1 parent 459cf4c commit 1bec95e
Showing 1 changed file with 20 additions and 88 deletions.
108 changes: 20 additions & 88 deletions src/layer/x86/gemm_int8.h
Original file line number Diff line number Diff line change
Expand Up @@ -2194,8 +2194,8 @@ static void compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_s
#if __AVX512F__
if (elempack == 16)
{
// __m512 _v127 = _mm512_set1_ps(127.f);
// __m512 _v127_B_scale = _mm512_set1_ps(v127_B_scale);
__m512 _v127 = _mm512_set1_ps(127.f);
__m512 _v127_B_scale = _mm512_set1_ps(v127_B_scale);
for (int ii = 0; ii + 15 < max_ii; ii += 16)
{
const float* p0 = (const float*)A + (i + ii) * A_hstep;
Expand All @@ -2209,47 +2209,11 @@ static void compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_s
p0 += 16;
}

// __m512 _scale = _mm512_div_ps(_v127, _absmax0);
// __m512 _out_descale = _mm512_div_ps(_absmax0, _v127_B_scale);

// _mm512_store_ps(ps, _scale);
// _mm512_store_ps(pods, _out_descale);

float absmax[16];
_mm512_storeu_ps(absmax, _absmax0);

ps[0] = 127.f / absmax[0];
ps[1] = 127.f / absmax[1];
ps[2] = 127.f / absmax[2];
ps[3] = 127.f / absmax[3];
ps[4] = 127.f / absmax[4];
ps[5] = 127.f / absmax[5];
ps[6] = 127.f / absmax[6];
ps[7] = 127.f / absmax[7];
ps[8] = 127.f / absmax[8];
ps[9] = 127.f / absmax[9];
ps[10] = 127.f / absmax[10];
ps[11] = 127.f / absmax[11];
ps[12] = 127.f / absmax[12];
ps[13] = 127.f / absmax[13];
ps[14] = 127.f / absmax[14];
ps[15] = 127.f / absmax[15];
pods[0] = absmax[0] / v127_B_scale;
pods[1] = absmax[1] / v127_B_scale;
pods[2] = absmax[2] / v127_B_scale;
pods[3] = absmax[3] / v127_B_scale;
pods[4] = absmax[4] / v127_B_scale;
pods[5] = absmax[5] / v127_B_scale;
pods[6] = absmax[6] / v127_B_scale;
pods[7] = absmax[7] / v127_B_scale;
pods[8] = absmax[8] / v127_B_scale;
pods[9] = absmax[9] / v127_B_scale;
pods[10] = absmax[10] / v127_B_scale;
pods[11] = absmax[11] / v127_B_scale;
pods[12] = absmax[12] / v127_B_scale;
pods[13] = absmax[13] / v127_B_scale;
pods[14] = absmax[14] / v127_B_scale;
pods[15] = absmax[15] / v127_B_scale;
__m512 _scale = _mm512_div_ps(_v127, _absmax0);
__m512 _out_descale = _mm512_div_ps(_absmax0, _v127_B_scale);

_mm512_store_ps(ps, _scale);
_mm512_store_ps(pods, _out_descale);

ps += 16;
pods += 16;
Expand All @@ -2258,8 +2222,8 @@ static void compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_s
#endif // __AVX512F__
if (elempack == 8)
{
// __m256 _v127 = _mm256_set1_ps(127.f);
// __m256 _v127_B_scale = _mm256_set1_ps(v127_B_scale);
__m256 _v127 = _mm256_set1_ps(127.f);
__m256 _v127_B_scale = _mm256_set1_ps(v127_B_scale);
for (int ii = 0; ii + 7 < max_ii; ii += 8)
{
const float* p0 = (const float*)A + (i + ii) * A_hstep;
Expand All @@ -2273,31 +2237,11 @@ static void compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_s
p0 += 8;
}

// __m256 _scale = _mm256_div_ps(_v127, _absmax0);
// __m256 _out_descale = _mm256_div_ps(_absmax0, _v127_B_scale);

// _mm256_store_ps(ps, _scale);
// _mm256_store_ps(pods, _out_descale);

float absmax[8];
_mm256_storeu_ps(absmax, _absmax0);

ps[0] = 127.f / absmax[0];
ps[1] = 127.f / absmax[1];
ps[2] = 127.f / absmax[2];
ps[3] = 127.f / absmax[3];
ps[4] = 127.f / absmax[4];
ps[5] = 127.f / absmax[5];
ps[6] = 127.f / absmax[6];
ps[7] = 127.f / absmax[7];
pods[0] = absmax[0] / v127_B_scale;
pods[1] = absmax[1] / v127_B_scale;
pods[2] = absmax[2] / v127_B_scale;
pods[3] = absmax[3] / v127_B_scale;
pods[4] = absmax[4] / v127_B_scale;
pods[5] = absmax[5] / v127_B_scale;
pods[6] = absmax[6] / v127_B_scale;
pods[7] = absmax[7] / v127_B_scale;
__m256 _scale = _mm256_div_ps(_v127, _absmax0);
__m256 _out_descale = _mm256_div_ps(_absmax0, _v127_B_scale);

_mm256_store_ps(ps, _scale);
_mm256_store_ps(pods, _out_descale);

ps += 8;
pods += 8;
Expand All @@ -2306,8 +2250,8 @@ static void compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_s
#endif // __AVX__
if (elempack == 4)
{
// __m128 _v127 = _mm_set1_ps(127.f);
// __m128 _v127_B_scale = _mm_set1_ps(v127_B_scale);
__m128 _v127 = _mm_set1_ps(127.f);
__m128 _v127_B_scale = _mm_set1_ps(v127_B_scale);
for (int ii = 0; ii + 3 < max_ii; ii += 4)
{
const float* p0 = (const float*)A + (i + ii) * A_hstep;
Expand All @@ -2321,23 +2265,11 @@ static void compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_s
p0 += 4;
}

// __m128 _scale = _mm_div_ps(_v127, _absmax0);
// __m128 _out_descale = _mm_div_ps(_absmax0, _v127_B_scale);

// _mm_store_ps(ps, _scale);
// _mm_store_ps(pods, _out_descale);

float absmax[4];
_mm_storeu_ps(absmax, _absmax0);
__m128 _scale = _mm_div_ps(_v127, _absmax0);
__m128 _out_descale = _mm_div_ps(_absmax0, _v127_B_scale);

ps[0] = 127.f / absmax[0];
ps[1] = 127.f / absmax[1];
ps[2] = 127.f / absmax[2];
ps[3] = 127.f / absmax[3];
pods[0] = absmax[0] / v127_B_scale;
pods[1] = absmax[1] / v127_B_scale;
pods[2] = absmax[2] / v127_B_scale;
pods[3] = absmax[3] / v127_B_scale;
_mm_store_ps(ps, _scale);
_mm_store_ps(pods, _out_descale);

ps += 4;
pods += 4;
Expand Down

0 comments on commit 1bec95e

Please sign in to comment.