From 86d1d84642e01532163a113f0c5a7ee16355a57d Mon Sep 17 00:00:00 2001 From: netrunnereve <139727413+netrunnereve@users.noreply.github.com> Date: Mon, 22 Apr 2024 23:35:02 -0400 Subject: [PATCH 1/7] basic avx implementation --- sgemm.cpp | 64 +++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 46 insertions(+), 18 deletions(-) diff --git a/sgemm.cpp b/sgemm.cpp index 531e12af361cc..b20b1115c328c 100644 --- a/sgemm.cpp +++ b/sgemm.cpp @@ -1,6 +1,3 @@ -// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- -// vi: set et ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi -// // Copyright 2024 Mozilla Foundation // // Permission is hereby granted, free of charge, to any person obtaining @@ -586,15 +583,15 @@ class tinyBLAS_Q0_ARM { }; #endif // __ARM_FEATURE_DOTPROD -#if defined(__AVX2__) || defined(__AVX512F__) +#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) template <typename TA, typename TB, typename TC> -class tinyBLAS_Q0_AVX2 { +class tinyBLAS_Q0_AVX { public: - tinyBLAS_Q0_AVX2(int k, - const TA *A, int lda, - const TB *B, int ldb, - TC *C, int ldc, - int ith, int nth) + tinyBLAS_Q0_AVX(int k, + const TA *A, int lda, + const TB *B, int ldb, + TC *C, int ldc, + int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } @@ -732,9 +729,9 @@ class tinyBLAS_Q0_AVX2 { for (int i = 0; i < RM; ++i) Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * unhalf(B[ldb * (jj + j) + l].d)), - updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), + updot(signepi8(load(A + lda * (ii + i) + l), load(A + lda * (ii + i) + l)), - _mm256_sign_epi8(load(B + ldb * (jj + j) + l), + signepi8(load(B + ldb * (jj + j) + l), load(A + lda * (ii + i) + l))), Cv[j][i]); for (int j = 0; j < RN; ++j) @@ -748,24 +745,55 @@ class tinyBLAS_Q0_AVX2 { } inline __m256i load(const block_q4_0 *b) { +#if defined(__AVX2__) return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8)); +#else + const __m128i dn0 = _mm256_extractf128_si256(denibble(b->qs), 
0); + const __m128i dn1 = _mm256_extractf128_si256(denibble(b->qs), 1); + return MM256_SET_M128I(_mm_sub_epi8(dn1, _mm_set1_epi8(8)), _mm_sub_epi8(dn0, _mm_set1_epi8(8))); +#endif } inline __m256 updot(__m256i u, __m256i s) { __m256i res; #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s); -#else +#elif defined(__AVX2__) res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s)); +#else + const __m128i usMaddubs0 = _mm_maddubs_epi16(_mm256_extractf128_si256(u, 0), _mm256_extractf128_si256(s, 0)); + const __m128i usMaddubs1 = _mm_maddubs_epi16(_mm256_extractf128_si256(u, 1), _mm256_extractf128_si256(s, 1)); + const __m128i onefill = _mm_set1_epi16(1); + res = MM256_SET_M128I(_mm_madd_epi16(onefill, usMaddubs1), _mm_madd_epi16(onefill, usMaddubs0)); #endif return _mm256_cvtepi32_ps(res); } +#if defined(__AVX2__) + inline __m256i signepi8(__m256i a, __m256i b) { + return _mm256_sign_epi8(a, b); + } +#else + inline __m256i signepi8(__m256i a, __m256i b) { + const __m128i a0 = _mm256_extractf128_si256(a, 0); + const __m128i a1 = _mm256_extractf128_si256(a, 1); + const __m128i b0 = _mm256_extractf128_si256(b, 0); + const __m128i b1 = _mm256_extractf128_si256(b, 1); + return MM256_SET_M128I(_mm_sign_epi8(a1, b1), _mm_sign_epi8(a0, b0)); + } +#endif + static inline __m256i denibble(const uint8_t *p) { __m128i x = _mm_loadu_si128((const __m128i *)p); +#if defined(__AVX2__) return _mm256_and_si256(_mm256_set1_epi8(15), _mm256_insertf128_si256(_mm256_castsi128_si256(x), _mm_srli_epi16(x, 4), 1)); +#else + const __m128i maskedLow = _mm_and_si128(_mm_set1_epi8(15), x); + const __m128i maskedHigh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)); + return MM256_SET_M128I(maskedHigh, maskedLow); +#endif } const TA *const A; @@ -778,7 +806,7 @@ class tinyBLAS_Q0_AVX2 { const int ith; const int nth; }; -#endif // __AVX2__ +#endif // __AVX__ } // namespace @@ -932,8 +960,8 @@ 
bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, case GGML_TYPE_Q8_0: { if (Btype != GGML_TYPE_Q8_0) return false; -#if defined(__AVX2__) || defined(__AVX512F__) - tinyBLAS_Q0_AVX2<block_q8_0, block_q8_0, float> tb{ +#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) + tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{ k, (const block_q8_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, @@ -956,8 +984,8 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, case GGML_TYPE_Q4_0: { if (Btype != GGML_TYPE_Q8_0) return false; -#if defined(__AVX2__) || defined(__AVX512F__) - tinyBLAS_Q0_AVX2<block_q4_0, block_q8_0, float> tb{ +#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) + tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{ k, (const block_q4_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, From 257391aae304882c03eee4edebf55c03dcfb1d02 Mon Sep 17 00:00:00 2001 From: netrunnereve <139727413+netrunnereve@users.noreply.github.com> Date: Mon, 22 Apr 2024 23:48:07 -0400 Subject: [PATCH 2/7] style --- sgemm.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sgemm.cpp b/sgemm.cpp index b20b1115c328c..5fd18549f4ace 100644 --- a/sgemm.cpp +++ b/sgemm.cpp @@ -763,8 +763,8 @@ class tinyBLAS_Q0_AVX { #else const __m128i usMaddubs0 = _mm_maddubs_epi16(_mm256_extractf128_si256(u, 0), _mm256_extractf128_si256(s, 0)); const __m128i usMaddubs1 = _mm_maddubs_epi16(_mm256_extractf128_si256(u, 1), _mm256_extractf128_si256(s, 1)); - const __m128i onefill = _mm_set1_epi16(1); - res = MM256_SET_M128I(_mm_madd_epi16(onefill, usMaddubs1), _mm_madd_epi16(onefill, usMaddubs0)); + const __m128i oneFill = _mm_set1_epi16(1); + res = MM256_SET_M128I(_mm_madd_epi16(oneFill, usMaddubs1), _mm_madd_epi16(oneFill, usMaddubs0)); #endif return _mm256_cvtepi32_ps(res); } From 9facb0f07a99383ae0eef8de2f044ba52902649b Mon Sep 17 00:00:00 2001 From: netrunnereve <139727413+netrunnereve@users.noreply.github.com> Date: Tue, 23 Apr 2024 23:46:49 -0400 Subject: [PATCH 3/7] combine denibble with load --- 
sgemm.cpp | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/sgemm.cpp b/sgemm.cpp index 5fd18549f4ace..059674d1d7477 100644 --- a/sgemm.cpp +++ b/sgemm.cpp @@ -726,14 +726,18 @@ class tinyBLAS_Q0_AVX { __m256 Cv[RN][RM] = {}; for (int l = 0; l < k; ++l) for (int j = 0; j < RN; ++j) - for (int i = 0; i < RM; ++i) - Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * - unhalf(B[ldb * (jj + j) + l].d)), - updot(signepi8(load(A + lda * (ii + i) + l), + for (int i = 0; i < RM; ++i) { + __m256 udTmp = updot(signepi8(load(A + lda * (ii + i) + l), load(A + lda * (ii + i) + l)), signepi8(load(B + ldb * (jj + j) + l), - load(A + lda * (ii + i) + l))), - Cv[j][i]); + load(A + lda * (ii + i) + l))); + //_mm256i ali = load(A + lda * (ii + i) + l; + //_mm256i blj = load(B + ldb * (jj + j) + l; + Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * + unhalf(B[ldb * (jj + j) + l].d)), + udTmp, + Cv[j][i]); + } for (int j = 0; j < RN; ++j) for (int i = 0; i < RM; ++i) C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); @@ -748,8 +752,10 @@ class tinyBLAS_Q0_AVX { #if defined(__AVX2__) return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8)); #else - const __m128i dn0 = _mm256_extractf128_si256(denibble(b->qs), 0); - const __m128i dn1 = _mm256_extractf128_si256(denibble(b->qs), 1); + __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); + const __m128i dn0 = _mm_and_si128(_mm_set1_epi8(15), x); + const __m128i dn1 = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)); + return MM256_SET_M128I(_mm_sub_epi8(dn1, _mm_set1_epi8(8)), _mm_sub_epi8(dn0, _mm_set1_epi8(8))); #endif } @@ -785,15 +791,9 @@ class tinyBLAS_Q0_AVX { static inline __m256i denibble(const uint8_t *p) { __m128i x = _mm_loadu_si128((const __m128i *)p); -#if defined(__AVX2__) return _mm256_and_si256(_mm256_set1_epi8(15), _mm256_insertf128_si256(_mm256_castsi128_si256(x), _mm_srli_epi16(x, 4), 1)); -#else - const __m128i maskedLow = 
_mm_and_si128(_mm_set1_epi8(15), x); - const __m128i maskedHigh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)); - return MM256_SET_M128I(maskedHigh, maskedLow); -#endif } const TA *const A; From dee9566dc7e7d38a0c43cbed8c5e5678b8d4e46a Mon Sep 17 00:00:00 2001 From: netrunnereve <139727413+netrunnereve@users.noreply.github.com> Date: Wed, 24 Apr 2024 00:22:38 -0400 Subject: [PATCH 4/7] reduce 256 to 128 (and back!) conversions --- sgemm.cpp | 51 ++++++++++++++++++++++++--------------------------- 1 file changed, 24 insertions(+), 27 deletions(-) diff --git a/sgemm.cpp b/sgemm.cpp index 059674d1d7477..f87652c829075 100644 --- a/sgemm.cpp +++ b/sgemm.cpp @@ -727,17 +727,33 @@ class tinyBLAS_Q0_AVX { for (int l = 0; l < k; ++l) for (int j = 0; j < RN; ++j) for (int i = 0; i < RM; ++i) { - __m256 udTmp = updot(signepi8(load(A + lda * (ii + i) + l), - load(A + lda * (ii + i) + l)), - signepi8(load(B + ldb * (jj + j) + l), - load(A + lda * (ii + i) + l))); - //_mm256i ali = load(A + lda * (ii + i) + l; - //_mm256i blj = load(B + ldb * (jj + j) + l; +#if defined(__AVX2__) + __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), + load(A + lda * (ii + i) + l)), + _mm256_sign_epi8(load(B + ldb * (jj + j) + l), + load(A + lda * (ii + i) + l))); +#else + __m128i ali0 = _mm256_extractf128_si256(load(A + lda * (ii + i) + l), 0); + __m128i ali1 = _mm256_extractf128_si256(load(A + lda * (ii + i) + l), 1); + __m128i blj0 = _mm256_extractf128_si256(load(B + ldb * (jj + j) + l), 0); + __m128i blj1 = _mm256_extractf128_si256(load(B + ldb * (jj + j) + l), 1); + + __m128i sepAA0 = _mm_sign_epi8(ali0, ali0); + __m128i sepAA1 = _mm_sign_epi8(ali1, ali1); + __m128i sepBA0 = _mm_sign_epi8(blj0, ali0); + __m128i sepBA1 = _mm_sign_epi8(blj1, ali1); + + // updot + const __m128i oneFill = _mm_set1_epi16(1); + __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0); + __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1); + __m256 udTmp = 
_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0))); +#endif Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * unhalf(B[ldb * (jj + j) + l].d)), udTmp, Cv[j][i]); - } + } for (int j = 0; j < RN; ++j) for (int i = 0; i < RM; ++i) C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); @@ -764,31 +780,12 @@ class tinyBLAS_Q0_AVX { __m256i res; #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s); -#elif defined(__AVX2__) - res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s)); #else - const __m128i usMaddubs0 = _mm_maddubs_epi16(_mm256_extractf128_si256(u, 0), _mm256_extractf128_si256(s, 0)); - const __m128i usMaddubs1 = _mm_maddubs_epi16(_mm256_extractf128_si256(u, 1), _mm256_extractf128_si256(s, 1)); - const __m128i oneFill = _mm_set1_epi16(1); - res = MM256_SET_M128I(_mm_madd_epi16(oneFill, usMaddubs1), _mm_madd_epi16(oneFill, usMaddubs0)); + res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s)); #endif return _mm256_cvtepi32_ps(res); } -#if defined(__AVX2__) - inline __m256i signepi8(__m256i a, __m256i b) { - return _mm256_sign_epi8(a, b); - } -#else - inline __m256i signepi8(__m256i a, __m256i b) { - const __m128i a0 = _mm256_extractf128_si256(a, 0); - const __m128i a1 = _mm256_extractf128_si256(a, 1); - const __m128i b0 = _mm256_extractf128_si256(b, 0); - const __m128i b1 = _mm256_extractf128_si256(b, 1); - return MM256_SET_M128I(_mm_sign_epi8(a1, b1), _mm_sign_epi8(a0, b0)); - } -#endif - static inline __m256i denibble(const uint8_t *p) { __m128i x = _mm_loadu_si128((const __m128i *)p); return _mm256_and_si256(_mm256_set1_epi8(15), From 063a31f7a880ca9810484b36090b47e3933d9202 Mon Sep 17 00:00:00 2001 From: netrunnereve <139727413+netrunnereve@users.noreply.github.com> Date: Wed, 24 Apr 2024 23:00:02 -0400 Subject: [PATCH 5/7] sse load --- sgemm.cpp | 32 +++++++++++++++++++++----------- 1 
file changed, 21 insertions(+), 11 deletions(-) diff --git a/sgemm.cpp b/sgemm.cpp index f87652c829075..c0eb998bc2844 100644 --- a/sgemm.cpp +++ b/sgemm.cpp @@ -733,10 +733,10 @@ class tinyBLAS_Q0_AVX { _mm256_sign_epi8(load(B + ldb * (jj + j) + l), load(A + lda * (ii + i) + l))); #else - __m128i ali0 = _mm256_extractf128_si256(load(A + lda * (ii + i) + l), 0); - __m128i ali1 = _mm256_extractf128_si256(load(A + lda * (ii + i) + l), 1); - __m128i blj0 = _mm256_extractf128_si256(load(B + ldb * (jj + j) + l), 0); - __m128i blj1 = _mm256_extractf128_si256(load(B + ldb * (jj + j) + l), 1); + __m128i ali0 = load0(A + lda * (ii + i) + l); + __m128i ali1 = load1(A + lda * (ii + i) + l); + __m128i blj0 = load0(B + ldb * (jj + j) + l); + __m128i blj1 = load1(B + ldb * (jj + j) + l); __m128i sepAA0 = _mm_sign_epi8(ali0, ali0); __m128i sepAA1 = _mm_sign_epi8(ali1, ali1); @@ -764,16 +764,26 @@ class tinyBLAS_Q0_AVX { return _mm256_loadu_si256((const __m256i *)b->qs); } + inline __m128i load0(const block_q8_0 *b) { + return _mm_loadu_si128((const __m128i *)b->qs); + } + + inline __m128i load1(const block_q8_0 *b) { + return _mm_loadu_si128(((const __m128i *)b->qs) + 1); + } + inline __m256i load(const block_q4_0 *b) { -#if defined(__AVX2__) return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8)); -#else - __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); - const __m128i dn0 = _mm_and_si128(_mm_set1_epi8(15), x); - const __m128i dn1 = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)); + } - return MM256_SET_M128I(_mm_sub_epi8(dn1, _mm_set1_epi8(8)), _mm_sub_epi8(dn0, _mm_set1_epi8(8))); -#endif + inline __m128i load0(const block_q4_0 *b) { + const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); + return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8)); + } + + inline __m128i load1(const block_q4_0 *b) { + const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); + return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 
4)), _mm_set1_epi8(8)); } inline __m256 updot(__m256i u, __m256i s) { From fb80f13cd4a02887d719969e72e28e7bef45f702 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Thu, 25 Apr 2024 04:03:29 +0000 Subject: [PATCH 6/7] Update sgemm.cpp --- sgemm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sgemm.cpp b/sgemm.cpp index c0eb998bc2844..0f58e949e6cb8 100644 --- a/sgemm.cpp +++ b/sgemm.cpp @@ -737,7 +737,7 @@ class tinyBLAS_Q0_AVX { __m128i ali1 = load1(A + lda * (ii + i) + l); __m128i blj0 = load0(B + ldb * (jj + j) + l); __m128i blj1 = load1(B + ldb * (jj + j) + l); - + __m128i sepAA0 = _mm_sign_epi8(ali0, ali0); __m128i sepAA1 = _mm_sign_epi8(ali1, ali1); __m128i sepBA0 = _mm_sign_epi8(blj0, ali0); From ae0b5ea7ae569c6e4782a91abc9dc02746535bfe Mon Sep 17 00:00:00 2001 From: netrunnereve <139727413+netrunnereve@users.noreply.github.com> Date: Mon, 29 Apr 2024 22:45:59 -0400 Subject: [PATCH 7/7] oops oops --- sgemm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sgemm.cpp b/sgemm.cpp index 5e301b323d6f5..40ba9d7e9a7b7 100644 --- a/sgemm.cpp +++ b/sgemm.cpp @@ -725,7 +725,7 @@ class tinyBLAS_Q0_AVX { __m256 Cv[RN][RM] = {}; for (int64_t l = 0; l < k; ++l) for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) + for (int64_t i = 0; i < RM; ++i) { #if defined(__AVX2__) __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), load(A + lda * (ii + i) + l)),