From 01eadf28ae1404a72c02ff92c7b7c43819825fc3 Mon Sep 17 00:00:00 2001 From: nihui Date: Sun, 1 Dec 2024 19:40:27 +0800 Subject: [PATCH] general rvv gemm --- src/layer/riscv/gemm_riscv.cpp | 2952 +++++--------------------------- src/layer/riscv/gemm_riscv.h | 3 +- 2 files changed, 445 insertions(+), 2510 deletions(-) diff --git a/src/layer/riscv/gemm_riscv.cpp b/src/layer/riscv/gemm_riscv.cpp index d49fd8376be..0c9862d55d0 100644 --- a/src/layer/riscv/gemm_riscv.cpp +++ b/src/layer/riscv/gemm_riscv.cpp @@ -33,153 +33,46 @@ Gemm_riscv::Gemm_riscv() support_inplace = false; nT = 0; -#if __riscv_vector - // When processing float data, - // even if the current hardware provides vector registers of more than 128 bits, - // vl=4 is still used, even though this will waste the width of the vector register. - vl = __riscv_vsetvlmax_e32m1(); - vl = vl >= 4 ? 4 : vl; -#else - vl = 0; -#endif // __riscv_vector } -static void pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, size_t vl) +static void pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) { +#if __riscv_vector + const int packn = csrr_vlenb() / 4; + const size_t vl = __riscv_vsetvl_e32m1(packn); +#endif + const int elempack = A.elempack; const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; + // NCNN_LOGE("pack_A_tile %d", elempack); + float* pp = AT; int ii = 0; #if __riscv_vector - for (; ii + 7 < max_ii; ii += 8) + for (; ii + (packn - 1) < max_ii; ii += packn) { - if (elempack == 4) + if (elempack == packn) { - const float* p0 = (const float*)A + (i + ii) * A_hstep + k * 4; - const float* p1 = (const float*)A + (i + ii + 4) * A_hstep + k * 4; + const float* p0 = (const float*)A + (i + ii) * A_hstep + k * packn; for (int kk = 0; kk < max_kk; kk++) { __riscv_vse32_v_f32m1(pp, __riscv_vle32_v_f32m1(p0, vl), vl); - __riscv_vse32_v_f32m1(pp + 4, __riscv_vle32_v_f32m1(p1, vl), vl); - pp += 8; - p0 += 4; - p1 += 4; + pp += packn; + p0 += packn; } } if (elempack == 1) { const float* p0 = (const float*)A + (i + ii) * A_hstep + k; - const float* p1 = (const float*)A + (i + ii + 1) * A_hstep + k; - const float* p2 = (const float*)A + (i + ii + 2) * A_hstep + k; - const float* p3 = (const float*)A + (i + ii + 3) * A_hstep + k; - const float* p4 = (const float*)A + (i + ii + 4) * A_hstep + k; - const float* p5 = (const float*)A + (i + ii + 5) * A_hstep + k; - const float* p6 = (const float*)A + (i + ii + 6) * A_hstep + k; - const float* p7 = (const float*)A + (i + ii + 7) * A_hstep + k; - - int kk = 0; - for (; kk + 7 < max_kk; kk += 8) - { - vfloat32m1_t _r0l = __riscv_vle32_v_f32m1(p0, vl); - vfloat32m1_t _r0h = __riscv_vle32_v_f32m1(p0 + 4, vl); - vfloat32m1_t _r1l = __riscv_vle32_v_f32m1(p1, vl); - vfloat32m1_t _r1h = __riscv_vle32_v_f32m1(p1 + 4, vl); - vfloat32m1_t _r2l = __riscv_vle32_v_f32m1(p2, vl); - vfloat32m1_t _r2h = __riscv_vle32_v_f32m1(p2 + 4, vl); - vfloat32m1_t _r3l = __riscv_vle32_v_f32m1(p3, vl); - vfloat32m1_t _r3h = __riscv_vle32_v_f32m1(p3 + 4, vl); - vfloat32m1_t _r4l = __riscv_vle32_v_f32m1(p4, vl); - vfloat32m1_t _r4h = __riscv_vle32_v_f32m1(p4 + 4, vl); - vfloat32m1_t _r5l = __riscv_vle32_v_f32m1(p5, vl); - vfloat32m1_t _r5h = __riscv_vle32_v_f32m1(p5 + 4, vl); - vfloat32m1_t _r6l = __riscv_vle32_v_f32m1(p6, vl); - vfloat32m1_t _r6h = __riscv_vle32_v_f32m1(p6 + 4, vl); - vfloat32m1_t _r7l = __riscv_vle32_v_f32m1(p7, vl); - vfloat32m1_t _r7h = __riscv_vle32_v_f32m1(p7 + 4, vl); - - __riscv_vsseg8e32_v_f32m1x8(pp, __riscv_vcreate_v_f32m1x8(_r0l, _r1l, _r2l, _r3l, _r4l, _r5l, _r6l, 
_r7l), vl); - __riscv_vsseg8e32_v_f32m1x8(pp + 32, __riscv_vcreate_v_f32m1x8(_r0h, _r1h, _r2h, _r3h, _r4h, _r5h, _r6h, _r7h), vl); - - pp += 64; - p0 += 8; - p1 += 8; - p2 += 8; - p3 += 8; - p4 += 8; - p5 += 8; - p6 += 8; - p7 += 8; - } - for (; kk < max_kk; kk++) - { - pp[0] = p0[0]; - pp[1] = p1[0]; - pp[2] = p2[0]; - pp[3] = p3[0]; - pp[4] = p4[0]; - pp[5] = p5[0]; - pp[6] = p6[0]; - pp[7] = p7[0]; - pp += 8; - p0++; - p1++; - p2++; - p3++; - p4++; - p5++; - p6++; - p7++; - } - } - } - for (; ii + 3 < max_ii; ii += 4) - { - if (elempack == 4) - { - const float* p0 = (const float*)A + (i + ii) * A_hstep + k * 4; for (int kk = 0; kk < max_kk; kk++) { - __riscv_vse32_v_f32m1(pp, __riscv_vle32_v_f32m1(p0, vl), vl); - pp += 4; - p0 += 4; - } - } - if (elempack == 1) - { - const float* p0 = (const float*)A + (i + ii) * A_hstep + k; - const float* p1 = (const float*)A + (i + ii + 1) * A_hstep + k; - const float* p2 = (const float*)A + (i + ii + 2) * A_hstep + k; - const float* p3 = (const float*)A + (i + ii + 3) * A_hstep + k; - - int kk = 0; - for (; kk + 3 < max_kk; kk += 4) - { - vfloat32m1_t v0 = __riscv_vle32_v_f32m1(p0, vl); - vfloat32m1_t v1 = __riscv_vle32_v_f32m1(p1, vl); - vfloat32m1_t v2 = __riscv_vle32_v_f32m1(p2, vl); - vfloat32m1_t v3 = __riscv_vle32_v_f32m1(p3, vl); - __riscv_vsseg4e32_v_f32m1x4(pp, __riscv_vcreate_v_f32m1x4(v0, v1, v2, v3), vl); - pp += 16; - p0 += 4; - p1 += 4; - p2 += 4; - p3 += 4; - } - for (; kk < max_kk; kk++) - { - pp[0] = p0[0]; - pp[1] = p1[0]; - pp[2] = p2[0]; - pp[3] = p3[0]; - pp += 4; + __riscv_vse32_v_f32m1(pp, __riscv_vlse32_v_f32m1(p0, A_hstep * sizeof(float), vl), vl); + pp += packn; p0++; - p1++; - p2++; - p3++; } } } @@ -193,14 +86,14 @@ static void pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max int kk = 0; #if __riscv_vector - for (; kk + 3 < max_kk; kk += 4) + for (; kk + (packn - 1) < max_kk; kk += packn) { vfloat32m1_t v0 = __riscv_vle32_v_f32m1(p0, vl); vfloat32m1_t v1 = __riscv_vle32_v_f32m1(p1, vl); __riscv_vsseg2e32_v_f32m1x2(pp, __riscv_vcreate_v_f32m1x2(v0, v1), vl); - pp += 8; - p0 += 4; - p1 += 4; + pp += packn * 2; + p0 += packn; + p1 += packn; } #endif // __riscv_vector for (; kk < max_kk; kk++) @@ -221,11 +114,11 @@ static void pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max int kk = 0; #if __riscv_vector - for (; kk + 3 < max_kk; kk += 4) + for (; kk + (packn - 1) < max_kk; kk += packn) { __riscv_vse32_v_f32m1(pp, __riscv_vle32_v_f32m1(p0, vl), vl); - pp += 4; - p0 += 4; + pp += packn; + p0 += packn; } #endif // __riscv_vector for (; kk < max_kk; kk++) @@ -238,68 +131,38 @@ static void pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max } } -static void transpose_pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, size_t vl) +static void transpose_pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) { +#if __riscv_vector + const int packn = csrr_vlenb() / 4; + const size_t vl = __riscv_vsetvl_e32m1(packn); +#endif + const int elempack = A.elempack; const int A_hstep = A.dims == 3 ? 
(int)A.cstep : A.w; + // NCNN_LOGE("transpose_pack_A_tile %d", elempack); + float* pp = AT; int ii = 0; #if __riscv_vector - for (; ii + 7 < max_ii; ii += 8) - { - if (elempack == 4) - { - const float* p0 = (const float*)A + k * A_hstep + (i + ii) * 4; - - int kk = 0; - for (; kk + 3 < max_kk; kk += 4) - { - vfloat32m1x4_t _r0 = __riscv_vlseg4e32_v_f32m1x4(p0, vl); - vfloat32m1x4_t _r1 = __riscv_vlseg4e32_v_f32m1x4(p0 + 16, vl); - __riscv_vse32_v_f32m1(pp, __riscv_vget_v_f32m1x4_f32m1(_r0, 0), vl); - __riscv_vse32_v_f32m1(pp + 4, __riscv_vget_v_f32m1x4_f32m1(_r1, 0), vl); - __riscv_vse32_v_f32m1(pp + 4 * 2, __riscv_vget_v_f32m1x4_f32m1(_r0, 1), vl); - __riscv_vse32_v_f32m1(pp + 4 * 3, __riscv_vget_v_f32m1x4_f32m1(_r1, 1), vl); - __riscv_vse32_v_f32m1(pp + 4 * 4, __riscv_vget_v_f32m1x4_f32m1(_r0, 2), vl); - __riscv_vse32_v_f32m1(pp + 4 * 5, __riscv_vget_v_f32m1x4_f32m1(_r1, 2), vl); - __riscv_vse32_v_f32m1(pp + 4 * 6, __riscv_vget_v_f32m1x4_f32m1(_r0, 3), vl); - __riscv_vse32_v_f32m1(pp + 4 * 7, __riscv_vget_v_f32m1x4_f32m1(_r1, 3), vl); - pp += 32; - p0 += A_hstep * 4; - } - } - if (elempack == 1) - { - const float* p0 = (const float*)A + k * A_hstep + (i + ii); - - int kk = 0; - for (; kk < max_kk; kk++) - { - __riscv_vse32_v_f32m1(pp, __riscv_vle32_v_f32m1(p0, vl), vl); - __riscv_vse32_v_f32m1(pp + 4, __riscv_vle32_v_f32m1(p0 + 4, vl), vl); - pp += 8; - p0 += A_hstep; - } - } - } - for (; ii + 3 < max_ii; ii += 4) + for (; ii + (packn - 1) < max_ii; ii += packn) { - if (elempack == 4) + if (elempack == packn) { - const float* p0 = (const float*)A + k * A_hstep + (i + ii) * 4; + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * packn; int kk = 0; - for (; kk + 3 < max_kk; kk += 4) + for (; kk + (packn - 1) < max_kk; kk += packn) { - vfloat32m1x4_t _r0 = __riscv_vlseg4e32_v_f32m1x4(p0, vl); - __riscv_vse32_v_f32m1(pp, __riscv_vget_v_f32m1x4_f32m1(_r0, 0), vl); - __riscv_vse32_v_f32m1(pp + 4, __riscv_vget_v_f32m1x4_f32m1(_r0, 1), vl); - __riscv_vse32_v_f32m1(pp + 4 * 2, __riscv_vget_v_f32m1x4_f32m1(_r0, 2), vl); - __riscv_vse32_v_f32m1(pp + 4 * 3, __riscv_vget_v_f32m1x4_f32m1(_r0, 3), vl); - pp += 16; - p0 += A_hstep * 4; + // transposeNxN + for (int l = 0; l < packn; l++) + { + __riscv_vse32_v_f32m1(pp, __riscv_vlse32_v_f32m1(p0 + l, packn * sizeof(float), vl), vl); + pp += packn; + } + p0 += A_hstep * packn; } } if (elempack == 1) @@ -310,7 +173,7 @@ static void transpose_pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int for (; kk < max_kk; kk++) { __riscv_vse32_v_f32m1(pp, __riscv_vle32_v_f32m1(p0, vl), vl); - pp += 4; + pp += packn; p0 += A_hstep; } } @@ -319,18 +182,18 @@ static void transpose_pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int for (; ii + 1 < max_ii; ii += 2) { #if __riscv_vector - if (elempack == 4) + if (elempack == packn) { - const float* p0 = (const float*)A + k * A_hstep + (i + ii) * 4; + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * packn; int kk = 0; - for (; kk + 3 < max_kk; kk += 4) + for (; kk + (packn - 1) < max_kk; kk += packn) { vfloat32m1_t v0 = __riscv_vle32_v_f32m1(p0, vl); - vfloat32m1_t v1 = __riscv_vle32_v_f32m1(p0 + 4, vl); + vfloat32m1_t v1 = __riscv_vle32_v_f32m1(p0 + packn, vl); __riscv_vsseg2e32_v_f32m1x2(pp, __riscv_vcreate_v_f32m1x2(v0, v1), vl); - pp += 8; - p0 += A_hstep * 4; + pp += packn * 2; + p0 += A_hstep * packn; } } #endif // __riscv_vector @@ -351,16 +214,16 @@ static void transpose_pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int for (; ii < max_ii; ii += 1) { #if __riscv_vector - if 
(elempack == 4) + if (elempack == packn) { - const float* p0 = (const float*)A + k * A_hstep + (i + ii) * 4; + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * packn; int kk = 0; - for (; kk + 3 < max_kk; kk += 4) + for (; kk + (packn - 1) < max_kk; kk += packn) { __riscv_vse32_v_f32m1(pp, __riscv_vle32_v_f32m1(p0, vl), vl); - pp += 4; - p0 += A_hstep * 4; + pp += packn; + p0 += A_hstep * packn; } } #endif // __riscv_vector @@ -379,243 +242,44 @@ static void transpose_pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int } } -static void pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, size_t vl) +static void pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) { +#if __riscv_vector + const int packn = csrr_vlenb() / 4; + const size_t vl = __riscv_vsetvl_e32m1(packn); +#endif + const int elempack = B.elempack; const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; + // NCNN_LOGE("pack_B_tile %d", elempack); + float* pp = BT; int jj = 0; #if __riscv_vector - for (; jj + 11 < max_jj; jj += 12) - { - if (elempack == 4) - { - const float* p0 = (const float*)B + (j + jj) * B_hstep + k * 4; - const float* p1 = (const float*)B + (j + jj + 4) * B_hstep + k * 4; - const float* p2 = (const float*)B + (j + jj + 8) * B_hstep + k * 4; - - for (int kk = 0; kk < max_kk; kk++) - { - __riscv_vse32_v_f32m1(pp, __riscv_vle32_v_f32m1(p0, vl), vl); - __riscv_vse32_v_f32m1(pp + 4, __riscv_vle32_v_f32m1(p1, vl), vl); - __riscv_vse32_v_f32m1(pp + 8, __riscv_vle32_v_f32m1(p2, vl), vl); - pp += 12; - p0 += 4; - p1 += 4; - p2 += 4; - } - } - if (elempack == 1) - { - const float* p0 = (const float*)B + (j + jj) * B_hstep + k; - const float* p1 = (const float*)B + (j + jj + 1) * B_hstep + k; - const float* p2 = (const float*)B + (j + jj + 2) * B_hstep + k; - const float* p3 = (const float*)B + (j + jj + 3) * B_hstep + k; - const float* p4 = (const float*)B + (j + jj + 4) * B_hstep + k; - const float* p5 = (const float*)B + (j + jj + 5) * B_hstep + k; - const float* p6 = (const float*)B + (j + jj + 6) * B_hstep + k; - const float* p7 = (const float*)B + (j + jj + 7) * B_hstep + k; - const float* p8 = (const float*)B + (j + jj + 8) * B_hstep + k; - const float* p9 = (const float*)B + (j + jj + 9) * B_hstep + k; - const float* pa = (const float*)B + (j + jj + 10) * B_hstep + k; - const float* pb = (const float*)B + (j + jj + 11) * B_hstep + k; - - int kk = 0; - for (; kk + 3 < max_kk; kk += 4) - { - vfloat32m1_t _r0 = __riscv_vle32_v_f32m1(p0, vl); - vfloat32m1_t _r1 = __riscv_vle32_v_f32m1(p1, vl); - vfloat32m1_t _r2 = __riscv_vle32_v_f32m1(p2, vl); - vfloat32m1_t _r3 = __riscv_vle32_v_f32m1(p3, vl); - vfloat32m1_t _r4 = __riscv_vle32_v_f32m1(p4, vl); - vfloat32m1_t _r5 = __riscv_vle32_v_f32m1(p5, vl); - vfloat32m1_t _r6 = __riscv_vle32_v_f32m1(p6, vl); - vfloat32m1_t _r7 = __riscv_vle32_v_f32m1(p7, vl); - vfloat32m1_t _r8 = __riscv_vle32_v_f32m1(p8, vl); - vfloat32m1_t _r9 = __riscv_vle32_v_f32m1(p9, vl); - vfloat32m1_t _ra = __riscv_vle32_v_f32m1(pa, vl); - vfloat32m1_t _rb = __riscv_vle32_v_f32m1(pb, vl); - - transpose4x4_ps(_r0, _r1, _r2, _r3, vl); - transpose4x4_ps(_r4, _r5, _r6, _r7, vl); - transpose4x4_ps(_r8, _r9, _ra, _rb, vl); - - __riscv_vse32_v_f32m1(pp, _r0, vl); - __riscv_vse32_v_f32m1(pp + 4, _r4, vl); - __riscv_vse32_v_f32m1(pp + 4 * 2, _r8, vl); - __riscv_vse32_v_f32m1(pp + 4 * 3, _r1, vl); - __riscv_vse32_v_f32m1(pp + 4 * 4, _r5, vl); - __riscv_vse32_v_f32m1(pp + 4 * 5, _r9, vl); - __riscv_vse32_v_f32m1(pp + 4 * 6, _r2, vl); - 
__riscv_vse32_v_f32m1(pp + 4 * 7, _r6, vl); - __riscv_vse32_v_f32m1(pp + 4 * 8, _ra, vl); - __riscv_vse32_v_f32m1(pp + 4 * 9, _r3, vl); - __riscv_vse32_v_f32m1(pp + 4 * 10, _r7, vl); - __riscv_vse32_v_f32m1(pp + 4 * 11, _rb, vl); - pp += 48; - p0 += 4; - p1 += 4; - p2 += 4; - p3 += 4; - p4 += 4; - p5 += 4; - p6 += 4; - p7 += 4; - p8 += 4; - p9 += 4; - pa += 4; - pb += 4; - } - for (; kk < max_kk; kk++) - { - pp[0] = p0[0]; - pp[1] = p1[0]; - pp[2] = p2[0]; - pp[3] = p3[0]; - pp[4] = p4[0]; - pp[5] = p5[0]; - pp[6] = p6[0]; - pp[7] = p7[0]; - pp[8] = p8[0]; - pp[9] = p9[0]; - pp[10] = pa[0]; - pp[11] = pb[0]; - pp += 12; - p0++; - p1++; - p2++; - p3++; - p4++; - p5++; - p6++; - p7++; - p8++; - p9++; - pa++; - pb++; - } - } - } - for (; jj + 7 < max_jj; jj += 8) + for (; jj + (packn - 1) < max_jj; jj += packn) { - if (elempack == 4) + if (elempack == packn) { - const float* p0 = (const float*)B + (j + jj) * B_hstep + k * 4; - const float* p1 = (const float*)B + (j + jj + 4) * B_hstep + k * 4; + const float* p0 = (const float*)B + (j + jj) * B_hstep + k * packn; for (int kk = 0; kk < max_kk; kk++) { __riscv_vse32_v_f32m1(pp, __riscv_vle32_v_f32m1(p0, vl), vl); - __riscv_vse32_v_f32m1(pp + 4, __riscv_vle32_v_f32m1(p1, vl), vl); - pp += 8; - p0 += 4; - p1 += 4; + pp += packn; + p0 += packn; } } if (elempack == 1) { const float* p0 = (const float*)B + (j + jj) * B_hstep + k; - const float* p1 = (const float*)B + (j + jj + 1) * B_hstep + k; - const float* p2 = (const float*)B + (j + jj + 2) * B_hstep + k; - const float* p3 = (const float*)B + (j + jj + 3) * B_hstep + k; - const float* p4 = (const float*)B + (j + jj + 4) * B_hstep + k; - const float* p5 = (const float*)B + (j + jj + 5) * B_hstep + k; - const float* p6 = (const float*)B + (j + jj + 6) * B_hstep + k; - const float* p7 = (const float*)B + (j + jj + 7) * B_hstep + k; - - int kk = 0; - for (; kk + 3 < max_kk; kk += 4) - { - vfloat32m1_t _r0 = __riscv_vle32_v_f32m1(p0, vl); - vfloat32m1_t _r1 = __riscv_vle32_v_f32m1(p1, vl); - vfloat32m1_t _r2 = __riscv_vle32_v_f32m1(p2, vl); - vfloat32m1_t _r3 = __riscv_vle32_v_f32m1(p3, vl); - vfloat32m1_t _r4 = __riscv_vle32_v_f32m1(p4, vl); - vfloat32m1_t _r5 = __riscv_vle32_v_f32m1(p5, vl); - vfloat32m1_t _r6 = __riscv_vle32_v_f32m1(p6, vl); - vfloat32m1_t _r7 = __riscv_vle32_v_f32m1(p7, vl); - - __riscv_vsseg8e32_v_f32m1x8(pp, __riscv_vcreate_v_f32m1x8(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7), vl); - - pp += 32; - p0 += 4; - p1 += 4; - p2 += 4; - p3 += 4; - p4 += 4; - p5 += 4; - p6 += 4; - p7 += 4; - } - for (; kk < max_kk; kk++) - { - pp[0] = p0[0]; - pp[1] = p1[0]; - pp[2] = p2[0]; - pp[3] = p3[0]; - pp[4] = p4[0]; - pp[5] = p5[0]; - pp[6] = p6[0]; - pp[7] = p7[0]; - pp += 8; - p0++; - p1++; - p2++; - p3++; - p4++; - p5++; - p6++; - p7++; - } - } - } - for (; jj + 3 < max_jj; jj += 4) - { - if (elempack == 4) - { - const float* p0 = (const float*)B + (j + jj) * B_hstep + k * 4; for (int kk = 0; kk < max_kk; kk++) { - __riscv_vse32_v_f32m1(pp, __riscv_vle32_v_f32m1(p0, vl), vl); - pp += 4; - p0 += 4; - } - } - if (elempack == 1) - { - const float* p0 = (const float*)B + (j + jj) * B_hstep + k; - const float* p1 = (const float*)B + (j + jj + 1) * B_hstep + k; - const float* p2 = (const float*)B + (j + jj + 2) * B_hstep + k; - const float* p3 = (const float*)B + (j + jj + 3) * B_hstep + k; - - int kk = 0; - for (; kk + 3 < max_kk; kk += 4) - { - vfloat32m1_t v0 = __riscv_vle32_v_f32m1(p0, vl); - vfloat32m1_t v1 = __riscv_vle32_v_f32m1(p1, vl); - vfloat32m1_t v2 = __riscv_vle32_v_f32m1(p2, vl); - 
vfloat32m1_t v3 = __riscv_vle32_v_f32m1(p3, vl); - __riscv_vsseg4e32_v_f32m1x4(pp, __riscv_vcreate_v_f32m1x4(v0, v1, v2, v3), vl); - pp += 16; - p0 += 4; - p1 += 4; - p2 += 4; - p3 += 4; - } - for (; kk < max_kk; kk++) - { - pp[0] = p0[0]; - pp[1] = p1[0]; - pp[2] = p2[0]; - pp[3] = p3[0]; - pp += 4; + __riscv_vse32_v_f32m1(pp, __riscv_vlse32_v_f32m1(p0, B_hstep * sizeof(float), vl), vl); + pp += packn; p0++; - p1++; - p2++; - p3++; } } } @@ -629,14 +293,14 @@ static void pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int k, int max int kk = 0; #if __riscv_vector - for (; kk + 3 < max_kk; kk += 4) + for (; kk + (packn - 1) < max_kk; kk += packn) { vfloat32m1_t v0 = __riscv_vle32_v_f32m1(p0, vl); vfloat32m1_t v1 = __riscv_vle32_v_f32m1(p1, vl); __riscv_vsseg2e32_v_f32m1x2(pp, __riscv_vcreate_v_f32m1x2(v0, v1), vl); - pp += 8; - p0 += 4; - p1 += 4; + pp += packn * 2; + p0 += packn; + p1 += packn; } #endif // __riscv_vector for (; kk < max_kk; kk++) @@ -657,11 +321,11 @@ static void pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int k, int max int kk = 0; #if __riscv_vector - for (; kk + 3 < max_kk; kk += 4) + for (; kk + (packn - 1) < max_kk; kk += packn) { __riscv_vse32_v_f32m1(pp, __riscv_vle32_v_f32m1(p0, vl), vl); - pp += 4; - p0 += 4; + pp += packn; + p0 += packn; } #endif // __riscv_vector for (; kk < max_kk; kk++) @@ -674,111 +338,38 @@ static void pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int k, int max } } -static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, size_t vl) +static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) { +#if __riscv_vector + const int packn = csrr_vlenb() / 4; + const size_t vl = __riscv_vsetvl_e32m1(packn); +#endif + const int elempack = B.elempack; const int B_hstep = B.dims == 3 ? 
(int)B.cstep : B.w; + // NCNN_LOGE("transpose_pack_B_tile %d", elempack); + float* pp = BT; int jj = 0; #if __riscv_vector - for (; jj + 11 < max_jj; jj += 12) - { - if (elempack == 4) - { - const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; - - int kk = 0; - for (; kk + 3 < max_kk; kk += 4) - { - vfloat32m1x4_t _r0 = __riscv_vlseg4e32_v_f32m1x4(p0, vl); - vfloat32m1x4_t _r1 = __riscv_vlseg4e32_v_f32m1x4(p0 + 16, vl); - vfloat32m1x4_t _r2 = __riscv_vlseg4e32_v_f32m1x4(p0 + 32, vl); - __riscv_vse32_v_f32m1(pp, __riscv_vget_v_f32m1x4_f32m1(_r0, 0), vl); - __riscv_vse32_v_f32m1(pp + 4, __riscv_vget_v_f32m1x4_f32m1(_r1, 0), vl); - __riscv_vse32_v_f32m1(pp + 4 * 2, __riscv_vget_v_f32m1x4_f32m1(_r2, 0), vl); - __riscv_vse32_v_f32m1(pp + 4 * 3, __riscv_vget_v_f32m1x4_f32m1(_r0, 1), vl); - __riscv_vse32_v_f32m1(pp + 4 * 4, __riscv_vget_v_f32m1x4_f32m1(_r1, 1), vl); - __riscv_vse32_v_f32m1(pp + 4 * 5, __riscv_vget_v_f32m1x4_f32m1(_r2, 1), vl); - __riscv_vse32_v_f32m1(pp + 4 * 6, __riscv_vget_v_f32m1x4_f32m1(_r0, 2), vl); - __riscv_vse32_v_f32m1(pp + 4 * 7, __riscv_vget_v_f32m1x4_f32m1(_r1, 2), vl); - __riscv_vse32_v_f32m1(pp + 4 * 8, __riscv_vget_v_f32m1x4_f32m1(_r2, 2), vl); - __riscv_vse32_v_f32m1(pp + 4 * 9, __riscv_vget_v_f32m1x4_f32m1(_r0, 3), vl); - __riscv_vse32_v_f32m1(pp + 4 * 10, __riscv_vget_v_f32m1x4_f32m1(_r1, 3), vl); - __riscv_vse32_v_f32m1(pp + 4 * 11, __riscv_vget_v_f32m1x4_f32m1(_r2, 3), vl); - pp += 48; - p0 += B_hstep * 4; - } - } - if (elempack == 1) - { - const float* p0 = (const float*)B + k * B_hstep + (j + jj); - - int kk = 0; - for (; kk < max_kk; kk++) - { - __riscv_vse32_v_f32m1(pp, __riscv_vle32_v_f32m1(p0, vl), vl); - __riscv_vse32_v_f32m1(pp + 4, __riscv_vle32_v_f32m1(p0 + 4, vl), vl); - __riscv_vse32_v_f32m1(pp + 8, __riscv_vle32_v_f32m1(p0 + 8, vl), vl); - pp += 12; - p0 += B_hstep; - } - } - } - for (; jj + 7 < max_jj; jj += 8) - { - if (elempack == 4) - { - const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; - - int kk = 0; - for (; kk + 3 < max_kk; kk += 4) - { - vfloat32m1x4_t _r0 = __riscv_vlseg4e32_v_f32m1x4(p0, vl); - vfloat32m1x4_t _r1 = __riscv_vlseg4e32_v_f32m1x4(p0 + 16, vl); - __riscv_vse32_v_f32m1(pp, __riscv_vget_v_f32m1x4_f32m1(_r0, 0), vl); - __riscv_vse32_v_f32m1(pp + 4, __riscv_vget_v_f32m1x4_f32m1(_r1, 0), vl); - __riscv_vse32_v_f32m1(pp + 4 * 2, __riscv_vget_v_f32m1x4_f32m1(_r0, 1), vl); - __riscv_vse32_v_f32m1(pp + 4 * 3, __riscv_vget_v_f32m1x4_f32m1(_r1, 1), vl); - __riscv_vse32_v_f32m1(pp + 4 * 4, __riscv_vget_v_f32m1x4_f32m1(_r0, 2), vl); - __riscv_vse32_v_f32m1(pp + 4 * 5, __riscv_vget_v_f32m1x4_f32m1(_r1, 2), vl); - __riscv_vse32_v_f32m1(pp + 4 * 6, __riscv_vget_v_f32m1x4_f32m1(_r0, 3), vl); - __riscv_vse32_v_f32m1(pp + 4 * 7, __riscv_vget_v_f32m1x4_f32m1(_r1, 3), vl); - pp += 32; - p0 += B_hstep * 4; - } - } - if (elempack == 1) - { - const float* p0 = (const float*)B + k * B_hstep + (j + jj); - - int kk = 0; - for (; kk < max_kk; kk++) - { - __riscv_vse32_v_f32m1(pp, __riscv_vle32_v_f32m1(p0, vl), vl); - __riscv_vse32_v_f32m1(pp + 4, __riscv_vle32_v_f32m1(p0 + 4, vl), vl); - pp += 8; - p0 += B_hstep; - } - } - } - for (; jj + 3 < max_jj; jj += 4) + for (; jj + (packn - 1) < max_jj; jj += packn) { - if (elempack == 4) + if (elempack == packn) { - const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * packn; int kk = 0; - for (; kk + 3 < max_kk; kk += 4) + for (; kk + (packn - 1) < max_kk; kk += packn) { - vfloat32m1x4_t _r0 = 
__riscv_vlseg4e32_v_f32m1x4(p0, vl); - __riscv_vse32_v_f32m1(pp, __riscv_vget_v_f32m1x4_f32m1(_r0, 0), vl); - __riscv_vse32_v_f32m1(pp + 4, __riscv_vget_v_f32m1x4_f32m1(_r0, 1), vl); - __riscv_vse32_v_f32m1(pp + 4 * 2, __riscv_vget_v_f32m1x4_f32m1(_r0, 2), vl); - __riscv_vse32_v_f32m1(pp + 4 * 3, __riscv_vget_v_f32m1x4_f32m1(_r0, 3), vl); - pp += 16; - p0 += B_hstep * 4; + // transposeNxN + for (int l = 0; l < packn; l++) + { + __riscv_vse32_v_f32m1(pp, __riscv_vlse32_v_f32m1(p0 + l, packn * sizeof(float), vl), vl); + pp += packn; + } + p0 += B_hstep * packn; } } if (elempack == 1) @@ -789,7 +380,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int for (; kk < max_kk; kk++) { __riscv_vse32_v_f32m1(pp, __riscv_vle32_v_f32m1(p0, vl), vl); - pp += 4; + pp += packn; p0 += B_hstep; } } @@ -798,18 +389,18 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int for (; jj + 1 < max_jj; jj += 2) { #if __riscv_vector - if (elempack == 4) + if (elempack == packn) { - const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * packn; int kk = 0; - for (; kk + 3 < max_kk; kk += 4) + for (; kk + (packn - 1) < max_kk; kk += packn) { vfloat32m1_t v0 = __riscv_vle32_v_f32m1(p0, vl); - vfloat32m1_t v1 = __riscv_vle32_v_f32m1(p0 + 4, vl); + vfloat32m1_t v1 = __riscv_vle32_v_f32m1(p0 + packn, vl); __riscv_vsseg2e32_v_f32m1x2(pp, __riscv_vcreate_v_f32m1x2(v0, v1), vl); - pp += 8; - p0 += B_hstep * 4; + pp += packn * 2; + p0 += B_hstep * packn; } } #endif // __riscv_vector @@ -830,16 +421,16 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int for (; jj < max_jj; jj += 1) { #if __riscv_vector - if (elempack == 4) + if (elempack == packn) { - const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * packn; int kk = 0; - for (; kk + 3 < max_kk; kk += 4) + for (; kk + (packn - 1) < max_kk; kk += packn) { __riscv_vse32_v_f32m1(pp, __riscv_vle32_v_f32m1(p0, vl), vl); - pp += 4; - p0 += B_hstep * 4; + pp += packn; + p0 += B_hstep * packn; } } #endif // __riscv_vector @@ -858,67 +449,37 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int } } -static void transpose_unpack_output_tile(const Mat& topT, Mat& top_blob, int i, int max_ii, int j, int max_jj, size_t vl) +static void transpose_unpack_output_tile(const Mat& topT, Mat& top_blob, int i, int max_ii, int j, int max_jj) { +#if __riscv_vector + const int packn = csrr_vlenb() / 4; + const size_t vl = __riscv_vsetvl_e32m1(packn); +#endif + const int out_elempack = top_blob.elempack; const int out_hstep = top_blob.dims == 3 ? 
(int)top_blob.cstep : top_blob.w; + // NCNN_LOGE("transpose_unpack_output_tile %d", out_elempack); + const float* pp = topT; int ii = 0; #if __riscv_vector - for (; ii + 7 < max_ii; ii += 8) - { - if (out_elempack == 4) - { - float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * 4; - - for (int jj = 0; jj + 3 < max_jj; jj += 4) - { - vfloat32m1_t v0 = __riscv_vle32_v_f32m1(pp, vl); - vfloat32m1_t v1 = __riscv_vle32_v_f32m1(pp + 4, vl); - vfloat32m1_t v2 = __riscv_vle32_v_f32m1(pp + 8, vl); - vfloat32m1_t v3 = __riscv_vle32_v_f32m1(pp + 12, vl); - vfloat32m1_t v4 = __riscv_vle32_v_f32m1(pp + 16, vl); - vfloat32m1_t v5 = __riscv_vle32_v_f32m1(pp + 20, vl); - vfloat32m1_t v6 = __riscv_vle32_v_f32m1(pp + 24, vl); - vfloat32m1_t v7 = __riscv_vle32_v_f32m1(pp + 28, vl); - __riscv_vsseg4e32_v_f32m1x4(p0, __riscv_vcreate_v_f32m1x4(v0, v2, v4, v6), vl); - __riscv_vsseg4e32_v_f32m1x4(p0 + 16, __riscv_vcreate_v_f32m1x4(v1, v3, v5, v7), vl); - pp += 32; - p0 += out_hstep * 4; - } - } - if (out_elempack == 1) - { - float* p0 = (float*)top_blob + j * out_hstep + (i + ii); - - for (int jj = 0; jj < max_jj; jj += 1) - { - vfloat32m1_t _r0 = __riscv_vle32_v_f32m1(pp, vl); - vfloat32m1_t _r1 = __riscv_vle32_v_f32m1(pp + 4, vl); - __riscv_vse32_v_f32m1(p0, _r0, vl); - __riscv_vse32_v_f32m1(p0 + 4, _r1, vl); - pp += 8; - p0 += out_hstep; - } - } - } - for (; ii + 3 < max_ii; ii += 4) + for (; ii + (packn - 1) < max_ii; ii += packn) { - if (out_elempack == 4) + if (out_elempack == packn) { - float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * 4; + float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * packn; - for (int jj = 0; jj + 3 < max_jj; jj += 4) + for (int jj = 0; jj + (packn - 1) < max_jj; jj += packn) { - vfloat32m1_t v0 = __riscv_vle32_v_f32m1(pp, vl); - vfloat32m1_t v1 = __riscv_vle32_v_f32m1(pp + 4, vl); - vfloat32m1_t v2 = __riscv_vle32_v_f32m1(pp + 8, vl); - vfloat32m1_t v3 = __riscv_vle32_v_f32m1(pp + 12, vl); - __riscv_vsseg4e32_v_f32m1x4(p0, __riscv_vcreate_v_f32m1x4(v0, v1, v2, v3), vl); - pp += 16; - p0 += out_hstep * 4; + // transposeNxN + for (int l = 0; l < packn; l++) + { + __riscv_vsse32_v_f32m1(p0 + l, packn * sizeof(float), __riscv_vle32_v_f32m1(pp, vl), vl); + pp += packn; + } + p0 += out_hstep * packn; } } if (out_elempack == 1) @@ -929,7 +490,7 @@ static void transpose_unpack_output_tile(const Mat& topT, Mat& top_blob, int i, { vfloat32m1_t _r0 = __riscv_vle32_v_f32m1(pp, vl); __riscv_vse32_v_f32m1(p0, _r0, vl); - pp += 4; + pp += packn; p0 += out_hstep; } } @@ -938,22 +499,17 @@ static void transpose_unpack_output_tile(const Mat& topT, Mat& top_blob, int i, for (; ii + 1 < max_ii; ii += 2) { #if __riscv_vector - if (out_elempack == 4) + if (out_elempack == packn) { - float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * 4; + float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * packn; - for (int jj = 0; jj + 3 < max_jj; jj += 4) + for (int jj = 0; jj + (packn - 1) < max_jj; jj += packn) { - p0[0] = pp[0]; - p0[1] = pp[2]; - p0[2] = pp[4]; - p0[3] = pp[6]; - p0[4] = pp[1]; - p0[5] = pp[3]; - p0[6] = pp[5]; - p0[7] = pp[7]; - pp += 8; - p0 += out_hstep * 4; + vfloat32m1x2_t _s0 = __riscv_vlseg2e32_v_f32m1x2(pp, vl); + __riscv_vse32_v_f32m1(p0, __riscv_vget_v_f32m1x2_f32m1(_s0, 0), vl); + __riscv_vse32_v_f32m1(p0 + packn, __riscv_vget_v_f32m1x2_f32m1(_s0, 1), vl); + pp += packn * 2; + p0 += out_hstep * packn; } } #endif // __riscv_vector @@ -973,16 +529,16 @@ static void transpose_unpack_output_tile(const Mat& topT, Mat& top_blob, int i, for (; ii < max_ii; ii += 
1) { #if __riscv_vector - if (out_elempack == 4) + if (out_elempack == packn) { - float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * 4; + float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * packn; - for (int jj = 0; jj + 3 < max_jj; jj += 4) + for (int jj = 0; jj + (packn - 1) < max_jj; jj += packn) { vfloat32m1_t _r0 = __riscv_vle32_v_f32m1(pp, vl); __riscv_vse32_v_f32m1(p0, _r0, vl); - pp += 4; - p0 += out_hstep * 4; + pp += packn; + p0 += out_hstep * packn; } } #endif // __riscv_vector @@ -1000,8 +556,13 @@ static void transpose_unpack_output_tile(const Mat& topT, Mat& top_blob, int i, } } -static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, const Mat& CT_tile, Mat& topT_tile, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, int k, int max_kk, bool k_end, size_t vl) +static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, const Mat& CT_tile, Mat& topT_tile, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, int k, int max_kk, bool k_end) { +#if __riscv_vector + const int packn = csrr_vlenb() / 4; + const size_t vl = __riscv_vsetvl_e32m1(packn); +#endif + const int out_elempack = top_blob.elempack; const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; @@ -1013,7 +574,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons int ii = 0; #if __riscv_vector - for (; ii + 7 < max_ii; ii += 8) + for (; ii + (packn - 1) < max_ii; ii += packn) { float* outptr0 = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack; @@ -1032,1517 +593,244 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons } int jj = 0; - for (; jj + 11 < max_jj; jj += 12) + for (; jj + (packn - 1) < max_jj; jj += packn) { - vfloat32m1_t _sum00; - vfloat32m1_t _sum01; - vfloat32m1_t _sum10; - vfloat32m1_t _sum11; - vfloat32m1_t _sum20; - vfloat32m1_t _sum21; - vfloat32m1_t _sum30; - vfloat32m1_t _sum31; - vfloat32m1_t _sum40; - vfloat32m1_t _sum41; - vfloat32m1_t _sum50; - vfloat32m1_t _sum51; - vfloat32m1_t _sum60; - vfloat32m1_t _sum61; - vfloat32m1_t _sum70; - vfloat32m1_t _sum71; - vfloat32m1_t _sum80; - vfloat32m1_t _sum81; - vfloat32m1_t _sum90; - vfloat32m1_t _sum91; - vfloat32m1_t _suma0; - vfloat32m1_t _suma1; - vfloat32m1_t _sumb0; - vfloat32m1_t _sumb1; - - if (k == 0) + if (packn == 8) { - _sum00 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum01 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum10 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum11 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum20 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum21 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum30 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum31 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum40 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum41 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum50 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum51 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum60 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum61 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum70 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum71 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum80 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum81 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum90 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum91 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _suma0 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _suma1 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sumb0 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sumb1 = __riscv_vfmv_v_f_f32m1(0.f, vl); + vfloat32m1_t _sum0; + vfloat32m1_t _sum1; + vfloat32m1_t _sum2; + vfloat32m1_t _sum3; + vfloat32m1_t 
_sum4; + vfloat32m1_t _sum5; + vfloat32m1_t _sum6; + vfloat32m1_t _sum7; - if (pC) + if (k == 0) { - if (broadcast_type_C == 0) - { - _sum00 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum01 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum10 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum11 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum20 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum21 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum30 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum31 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum40 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum41 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum50 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum51 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum60 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum61 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum70 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum71 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum80 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum81 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum90 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum91 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _suma0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _suma1 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sumb0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sumb1 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _sum00 = __riscv_vle32_v_f32m1(pC, vl); - _sum01 = __riscv_vle32_v_f32m1(pC + 4, vl); - _sum10 = _sum00; - _sum11 = _sum01; - _sum20 = _sum00; - _sum21 = _sum01; - _sum30 = _sum00; - _sum31 = _sum01; - _sum40 = _sum00; - _sum41 = _sum01; - _sum50 = _sum00; - _sum51 = _sum01; - _sum60 = _sum00; - _sum61 = _sum01; - _sum70 = _sum00; - _sum71 = _sum01; - _sum80 = _sum00; - _sum81 = _sum01; - _sum90 = _sum00; - _sum91 = _sum01; - _suma0 = _sum00; - _suma1 = _sum01; - _sumb0 = _sum00; - _sumb1 = _sum01; - } - if (broadcast_type_C == 3) - { - _sum00 = __riscv_vle32_v_f32m1(pC, vl); - _sum01 = __riscv_vle32_v_f32m1(pC + 4 * 1, vl); - _sum10 = __riscv_vle32_v_f32m1(pC + 4 * 2, vl); - _sum11 = __riscv_vle32_v_f32m1(pC + 4 * 3, vl); - _sum20 = __riscv_vle32_v_f32m1(pC + 4 * 4, vl); - _sum21 = __riscv_vle32_v_f32m1(pC + 4 * 5, vl); - _sum30 = __riscv_vle32_v_f32m1(pC + 4 * 6, vl); - _sum31 = __riscv_vle32_v_f32m1(pC + 4 * 7, vl); - _sum40 = __riscv_vle32_v_f32m1(pC + 4 * 8, vl); - _sum41 = __riscv_vle32_v_f32m1(pC + 4 * 9, vl); - _sum50 = __riscv_vle32_v_f32m1(pC + 4 * 10, vl); - _sum51 = __riscv_vle32_v_f32m1(pC + 4 * 11, vl); - _sum60 = __riscv_vle32_v_f32m1(pC + 4 * 12, vl); - _sum61 = __riscv_vle32_v_f32m1(pC + 4 * 13, vl); - _sum70 = __riscv_vle32_v_f32m1(pC + 4 * 14, vl); - _sum71 = __riscv_vle32_v_f32m1(pC + 4 * 15, vl); - _sum80 = __riscv_vle32_v_f32m1(pC + 4 * 16, vl); - _sum81 = __riscv_vle32_v_f32m1(pC + 4 * 17, vl); - _sum90 = __riscv_vle32_v_f32m1(pC + 4 * 18, vl); - _sum91 = __riscv_vle32_v_f32m1(pC + 4 * 19, vl); - _suma0 = __riscv_vle32_v_f32m1(pC + 4 * 20, vl); - _suma1 = __riscv_vle32_v_f32m1(pC + 4 * 21, vl); - _sumb0 = __riscv_vle32_v_f32m1(pC + 4 * 22, vl); - _sumb1 = __riscv_vle32_v_f32m1(pC + 4 * 23, vl); - pC += 96; + _sum0 = __riscv_vfmv_v_f_f32m1(0.f, vl); + _sum1 = __riscv_vfmv_v_f_f32m1(0.f, vl); + _sum2 = __riscv_vfmv_v_f_f32m1(0.f, vl); + _sum3 = __riscv_vfmv_v_f_f32m1(0.f, vl); + _sum4 = __riscv_vfmv_v_f_f32m1(0.f, vl); + _sum5 = __riscv_vfmv_v_f_f32m1(0.f, vl); + _sum6 = __riscv_vfmv_v_f_f32m1(0.f, vl); + _sum7 = __riscv_vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + 
_sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = __riscv_vle32_v_f32m1(pC, vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = __riscv_vle32_v_f32m1(pC, vl); + _sum1 = __riscv_vle32_v_f32m1(pC + packn, vl); + _sum2 = __riscv_vle32_v_f32m1(pC + packn * 2, vl); + _sum3 = __riscv_vle32_v_f32m1(pC + packn * 3, vl); + _sum4 = __riscv_vle32_v_f32m1(pC + packn * 4, vl); + _sum5 = __riscv_vle32_v_f32m1(pC + packn * 5, vl); + _sum6 = __riscv_vle32_v_f32m1(pC + packn * 6, vl); + _sum7 = __riscv_vle32_v_f32m1(pC + packn * 7, vl); + pC += packn * 8; + } + if (broadcast_type_C == 4) + { + _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); + _sum1 = __riscv_vfmv_v_f_f32m1(pC[1], vl); + _sum2 = __riscv_vfmv_v_f_f32m1(pC[2], vl); + _sum3 = __riscv_vfmv_v_f_f32m1(pC[3], vl); + _sum4 = __riscv_vfmv_v_f_f32m1(pC[4], vl); + _sum5 = __riscv_vfmv_v_f_f32m1(pC[5], vl); + _sum6 = __riscv_vfmv_v_f_f32m1(pC[6], vl); + _sum7 = __riscv_vfmv_v_f_f32m1(pC[7], vl); + pC += 8; + } } - if (broadcast_type_C == 4) - { - _sum00 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum10 = __riscv_vfmv_v_f_f32m1(pC[1], vl); - _sum20 = __riscv_vfmv_v_f_f32m1(pC[2], vl); - _sum30 = __riscv_vfmv_v_f_f32m1(pC[3], vl); - _sum40 = __riscv_vfmv_v_f_f32m1(pC[4], vl); - _sum50 = __riscv_vfmv_v_f_f32m1(pC[5], vl); - _sum60 = __riscv_vfmv_v_f_f32m1(pC[6], vl); - _sum70 = __riscv_vfmv_v_f_f32m1(pC[7], vl); - _sum80 = __riscv_vfmv_v_f_f32m1(pC[8], vl); - _sum90 = __riscv_vfmv_v_f_f32m1(pC[9], vl); - _suma0 = __riscv_vfmv_v_f_f32m1(pC[10], vl); - _sumb0 = __riscv_vfmv_v_f_f32m1(pC[11], vl); - _sum01 = _sum00; - _sum11 = _sum10; - _sum21 = _sum20; - _sum31 = _sum30; - _sum41 = _sum40; - _sum51 = _sum50; - _sum61 = _sum60; - _sum71 = _sum70; - _sum81 = _sum80; - _sum91 = _sum90; - _suma1 = _suma0; - _sumb1 = _sumb0; - pC += 12; - } - } - } - else - { - _sum00 = __riscv_vle32_v_f32m1(outptr, vl); - _sum01 = __riscv_vle32_v_f32m1(outptr + 4 * 1, vl); - _sum10 = __riscv_vle32_v_f32m1(outptr + 4 * 2, vl); - _sum11 = __riscv_vle32_v_f32m1(outptr + 4 * 3, vl); - _sum20 = __riscv_vle32_v_f32m1(outptr + 4 * 4, vl); - _sum21 = __riscv_vle32_v_f32m1(outptr + 4 * 5, vl); - _sum30 = __riscv_vle32_v_f32m1(outptr + 4 * 6, vl); - _sum31 = __riscv_vle32_v_f32m1(outptr + 4 * 7, vl); - _sum40 = __riscv_vle32_v_f32m1(outptr + 4 * 8, vl); - _sum41 = __riscv_vle32_v_f32m1(outptr + 4 * 9, vl); - _sum50 = __riscv_vle32_v_f32m1(outptr + 4 * 10, vl); - _sum51 = __riscv_vle32_v_f32m1(outptr + 4 * 11, vl); - _sum60 = __riscv_vle32_v_f32m1(outptr + 4 * 12, vl); - _sum61 = __riscv_vle32_v_f32m1(outptr + 4 * 13, vl); - _sum70 = __riscv_vle32_v_f32m1(outptr + 4 * 14, vl); - _sum71 = __riscv_vle32_v_f32m1(outptr + 4 * 15, vl); - _sum80 = __riscv_vle32_v_f32m1(outptr + 4 * 16, vl); - _sum81 = __riscv_vle32_v_f32m1(outptr + 4 * 17, vl); - _sum90 = __riscv_vle32_v_f32m1(outptr + 4 * 18, vl); - _sum91 = __riscv_vle32_v_f32m1(outptr + 4 * 19, vl); - _suma0 = __riscv_vle32_v_f32m1(outptr + 4 * 20, vl); - _suma1 = __riscv_vle32_v_f32m1(outptr + 4 * 21, vl); - _sumb0 = __riscv_vle32_v_f32m1(outptr + 4 * 22, vl); - _sumb1 = __riscv_vle32_v_f32m1(outptr + 4 * 23, vl); - } - - const float* pA = pAT; - int kk = 0; - for (; kk + 3 < max_kk; kk += 4) - { - vfloat32m1_t _pA0 = __riscv_vle32_v_f32m1(pA, vl); - vfloat32m1_t _pA1 = __riscv_vle32_v_f32m1(pA + 4, vl); - - _sum00 = __riscv_vfmadd_vf_f32m1(_pA0, 
pB[0], _sum00, vl); - _sum01 = __riscv_vfmadd_vf_f32m1(_pA1, pB[0], _sum01, vl); - _sum10 = __riscv_vfmadd_vf_f32m1(_pA0, pB[1], _sum10, vl); - _sum11 = __riscv_vfmadd_vf_f32m1(_pA1, pB[1], _sum11, vl); - _sum20 = __riscv_vfmadd_vf_f32m1(_pA0, pB[2], _sum20, vl); - _sum21 = __riscv_vfmadd_vf_f32m1(_pA1, pB[2], _sum21, vl); - _sum30 = __riscv_vfmadd_vf_f32m1(_pA0, pB[3], _sum30, vl); - _sum31 = __riscv_vfmadd_vf_f32m1(_pA1, pB[3], _sum31, vl); - _sum40 = __riscv_vfmadd_vf_f32m1(_pA0, pB[4], _sum40, vl); - _sum41 = __riscv_vfmadd_vf_f32m1(_pA1, pB[4], _sum41, vl); - _sum50 = __riscv_vfmadd_vf_f32m1(_pA0, pB[5], _sum50, vl); - _sum51 = __riscv_vfmadd_vf_f32m1(_pA1, pB[5], _sum51, vl); - _sum60 = __riscv_vfmadd_vf_f32m1(_pA0, pB[6], _sum60, vl); - _sum61 = __riscv_vfmadd_vf_f32m1(_pA1, pB[6], _sum61, vl); - _sum70 = __riscv_vfmadd_vf_f32m1(_pA0, pB[7], _sum70, vl); - _sum71 = __riscv_vfmadd_vf_f32m1(_pA1, pB[7], _sum71, vl); - _sum80 = __riscv_vfmadd_vf_f32m1(_pA0, pB[8], _sum80, vl); - _sum81 = __riscv_vfmadd_vf_f32m1(_pA1, pB[8], _sum81, vl); - _sum90 = __riscv_vfmadd_vf_f32m1(_pA0, pB[9], _sum90, vl); - _sum91 = __riscv_vfmadd_vf_f32m1(_pA1, pB[9], _sum91, vl); - _suma0 = __riscv_vfmadd_vf_f32m1(_pA0, pB[10], _suma0, vl); - _suma1 = __riscv_vfmadd_vf_f32m1(_pA1, pB[10], _suma1, vl); - _sumb0 = __riscv_vfmadd_vf_f32m1(_pA0, pB[11], _sumb0, vl); - _sumb1 = __riscv_vfmadd_vf_f32m1(_pA1, pB[11], _sumb1, vl); - - pA += 8; - pB += 12; - - _pA0 = __riscv_vle32_v_f32m1(pA, vl); - _pA1 = __riscv_vle32_v_f32m1(pA + 4, vl); - - _sum00 = __riscv_vfmadd_vf_f32m1(_pA0, pB[0], _sum00, vl); - _sum01 = __riscv_vfmadd_vf_f32m1(_pA1, pB[0], _sum01, vl); - _sum10 = __riscv_vfmadd_vf_f32m1(_pA0, pB[1], _sum10, vl); - _sum11 = __riscv_vfmadd_vf_f32m1(_pA1, pB[1], _sum11, vl); - _sum20 = __riscv_vfmadd_vf_f32m1(_pA0, pB[2], _sum20, vl); - _sum21 = __riscv_vfmadd_vf_f32m1(_pA1, pB[2], _sum21, vl); - _sum30 = __riscv_vfmadd_vf_f32m1(_pA0, pB[3], _sum30, vl); - _sum31 = __riscv_vfmadd_vf_f32m1(_pA1, pB[3], _sum31, vl); - _sum40 = __riscv_vfmadd_vf_f32m1(_pA0, pB[4], _sum40, vl); - _sum41 = __riscv_vfmadd_vf_f32m1(_pA1, pB[4], _sum41, vl); - _sum50 = __riscv_vfmadd_vf_f32m1(_pA0, pB[5], _sum50, vl); - _sum51 = __riscv_vfmadd_vf_f32m1(_pA1, pB[5], _sum51, vl); - _sum60 = __riscv_vfmadd_vf_f32m1(_pA0, pB[6], _sum60, vl); - _sum61 = __riscv_vfmadd_vf_f32m1(_pA1, pB[6], _sum61, vl); - _sum70 = __riscv_vfmadd_vf_f32m1(_pA0, pB[7], _sum70, vl); - _sum71 = __riscv_vfmadd_vf_f32m1(_pA1, pB[7], _sum71, vl); - _sum80 = __riscv_vfmadd_vf_f32m1(_pA0, pB[8], _sum80, vl); - _sum81 = __riscv_vfmadd_vf_f32m1(_pA1, pB[8], _sum81, vl); - _sum90 = __riscv_vfmadd_vf_f32m1(_pA0, pB[9], _sum90, vl); - _sum91 = __riscv_vfmadd_vf_f32m1(_pA1, pB[9], _sum91, vl); - _suma0 = __riscv_vfmadd_vf_f32m1(_pA0, pB[10], _suma0, vl); - _suma1 = __riscv_vfmadd_vf_f32m1(_pA1, pB[10], _suma1, vl); - _sumb0 = __riscv_vfmadd_vf_f32m1(_pA0, pB[11], _sumb0, vl); - _sumb1 = __riscv_vfmadd_vf_f32m1(_pA1, pB[11], _sumb1, vl); - - pA += 8; - pB += 12; - - _pA0 = __riscv_vle32_v_f32m1(pA, vl); - _pA1 = __riscv_vle32_v_f32m1(pA + 4, vl); - - _sum00 = __riscv_vfmadd_vf_f32m1(_pA0, pB[0], _sum00, vl); - _sum01 = __riscv_vfmadd_vf_f32m1(_pA1, pB[0], _sum01, vl); - _sum10 = __riscv_vfmadd_vf_f32m1(_pA0, pB[1], _sum10, vl); - _sum11 = __riscv_vfmadd_vf_f32m1(_pA1, pB[1], _sum11, vl); - _sum20 = __riscv_vfmadd_vf_f32m1(_pA0, pB[2], _sum20, vl); - _sum21 = __riscv_vfmadd_vf_f32m1(_pA1, pB[2], _sum21, vl); - _sum30 = __riscv_vfmadd_vf_f32m1(_pA0, pB[3], _sum30, vl); - 
_sum31 = __riscv_vfmadd_vf_f32m1(_pA1, pB[3], _sum31, vl); - _sum40 = __riscv_vfmadd_vf_f32m1(_pA0, pB[4], _sum40, vl); - _sum41 = __riscv_vfmadd_vf_f32m1(_pA1, pB[4], _sum41, vl); - _sum50 = __riscv_vfmadd_vf_f32m1(_pA0, pB[5], _sum50, vl); - _sum51 = __riscv_vfmadd_vf_f32m1(_pA1, pB[5], _sum51, vl); - _sum60 = __riscv_vfmadd_vf_f32m1(_pA0, pB[6], _sum60, vl); - _sum61 = __riscv_vfmadd_vf_f32m1(_pA1, pB[6], _sum61, vl); - _sum70 = __riscv_vfmadd_vf_f32m1(_pA0, pB[7], _sum70, vl); - _sum71 = __riscv_vfmadd_vf_f32m1(_pA1, pB[7], _sum71, vl); - _sum80 = __riscv_vfmadd_vf_f32m1(_pA0, pB[8], _sum80, vl); - _sum81 = __riscv_vfmadd_vf_f32m1(_pA1, pB[8], _sum81, vl); - _sum90 = __riscv_vfmadd_vf_f32m1(_pA0, pB[9], _sum90, vl); - _sum91 = __riscv_vfmadd_vf_f32m1(_pA1, pB[9], _sum91, vl); - _suma0 = __riscv_vfmadd_vf_f32m1(_pA0, pB[10], _suma0, vl); - _suma1 = __riscv_vfmadd_vf_f32m1(_pA1, pB[10], _suma1, vl); - _sumb0 = __riscv_vfmadd_vf_f32m1(_pA0, pB[11], _sumb0, vl); - _sumb1 = __riscv_vfmadd_vf_f32m1(_pA1, pB[11], _sumb1, vl); - - pA += 8; - pB += 12; - - _pA0 = __riscv_vle32_v_f32m1(pA, vl); - _pA1 = __riscv_vle32_v_f32m1(pA + 4, vl); - - _sum00 = __riscv_vfmadd_vf_f32m1(_pA0, pB[0], _sum00, vl); - _sum01 = __riscv_vfmadd_vf_f32m1(_pA1, pB[0], _sum01, vl); - _sum10 = __riscv_vfmadd_vf_f32m1(_pA0, pB[1], _sum10, vl); - _sum11 = __riscv_vfmadd_vf_f32m1(_pA1, pB[1], _sum11, vl); - _sum20 = __riscv_vfmadd_vf_f32m1(_pA0, pB[2], _sum20, vl); - _sum21 = __riscv_vfmadd_vf_f32m1(_pA1, pB[2], _sum21, vl); - _sum30 = __riscv_vfmadd_vf_f32m1(_pA0, pB[3], _sum30, vl); - _sum31 = __riscv_vfmadd_vf_f32m1(_pA1, pB[3], _sum31, vl); - _sum40 = __riscv_vfmadd_vf_f32m1(_pA0, pB[4], _sum40, vl); - _sum41 = __riscv_vfmadd_vf_f32m1(_pA1, pB[4], _sum41, vl); - _sum50 = __riscv_vfmadd_vf_f32m1(_pA0, pB[5], _sum50, vl); - _sum51 = __riscv_vfmadd_vf_f32m1(_pA1, pB[5], _sum51, vl); - _sum60 = __riscv_vfmadd_vf_f32m1(_pA0, pB[6], _sum60, vl); - _sum61 = __riscv_vfmadd_vf_f32m1(_pA1, pB[6], _sum61, vl); - _sum70 = __riscv_vfmadd_vf_f32m1(_pA0, pB[7], _sum70, vl); - _sum71 = __riscv_vfmadd_vf_f32m1(_pA1, pB[7], _sum71, vl); - _sum80 = __riscv_vfmadd_vf_f32m1(_pA0, pB[8], _sum80, vl); - _sum81 = __riscv_vfmadd_vf_f32m1(_pA1, pB[8], _sum81, vl); - _sum90 = __riscv_vfmadd_vf_f32m1(_pA0, pB[9], _sum90, vl); - _sum91 = __riscv_vfmadd_vf_f32m1(_pA1, pB[9], _sum91, vl); - _suma0 = __riscv_vfmadd_vf_f32m1(_pA0, pB[10], _suma0, vl); - _suma1 = __riscv_vfmadd_vf_f32m1(_pA1, pB[10], _suma1, vl); - _sumb0 = __riscv_vfmadd_vf_f32m1(_pA0, pB[11], _sumb0, vl); - _sumb1 = __riscv_vfmadd_vf_f32m1(_pA1, pB[11], _sumb1, vl); - - pA += 8; - pB += 12; - } - for (; kk < max_kk; kk += 1) - { - vfloat32m1_t _pA0 = __riscv_vle32_v_f32m1(pA, vl); - vfloat32m1_t _pA1 = __riscv_vle32_v_f32m1(pA + 4, vl); - - _sum00 = __riscv_vfmadd_vf_f32m1(_pA0, pB[0], _sum00, vl); - _sum01 = __riscv_vfmadd_vf_f32m1(_pA1, pB[0], _sum01, vl); - _sum10 = __riscv_vfmadd_vf_f32m1(_pA0, pB[1], _sum10, vl); - _sum11 = __riscv_vfmadd_vf_f32m1(_pA1, pB[1], _sum11, vl); - _sum20 = __riscv_vfmadd_vf_f32m1(_pA0, pB[2], _sum20, vl); - _sum21 = __riscv_vfmadd_vf_f32m1(_pA1, pB[2], _sum21, vl); - _sum30 = __riscv_vfmadd_vf_f32m1(_pA0, pB[3], _sum30, vl); - _sum31 = __riscv_vfmadd_vf_f32m1(_pA1, pB[3], _sum31, vl); - _sum40 = __riscv_vfmadd_vf_f32m1(_pA0, pB[4], _sum40, vl); - _sum41 = __riscv_vfmadd_vf_f32m1(_pA1, pB[4], _sum41, vl); - _sum50 = __riscv_vfmadd_vf_f32m1(_pA0, pB[5], _sum50, vl); - _sum51 = __riscv_vfmadd_vf_f32m1(_pA1, pB[5], _sum51, vl); - _sum60 = 
__riscv_vfmadd_vf_f32m1(_pA0, pB[6], _sum60, vl); - _sum61 = __riscv_vfmadd_vf_f32m1(_pA1, pB[6], _sum61, vl); - _sum70 = __riscv_vfmadd_vf_f32m1(_pA0, pB[7], _sum70, vl); - _sum71 = __riscv_vfmadd_vf_f32m1(_pA1, pB[7], _sum71, vl); - _sum80 = __riscv_vfmadd_vf_f32m1(_pA0, pB[8], _sum80, vl); - _sum81 = __riscv_vfmadd_vf_f32m1(_pA1, pB[8], _sum81, vl); - _sum90 = __riscv_vfmadd_vf_f32m1(_pA0, pB[9], _sum90, vl); - _sum91 = __riscv_vfmadd_vf_f32m1(_pA1, pB[9], _sum91, vl); - _suma0 = __riscv_vfmadd_vf_f32m1(_pA0, pB[10], _suma0, vl); - _suma1 = __riscv_vfmadd_vf_f32m1(_pA1, pB[10], _suma1, vl); - _sumb0 = __riscv_vfmadd_vf_f32m1(_pA0, pB[11], _sumb0, vl); - _sumb1 = __riscv_vfmadd_vf_f32m1(_pA1, pB[11], _sumb1, vl); - - pA += 8; - pB += 12; - } - - if (k_end) - { - if (out_elempack == 4) - { - __riscv_vse32_v_f32m1(outptr0, _sum00, vl); - __riscv_vse32_v_f32m1(outptr0 + 4, _sum10, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 2, _sum20, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 3, _sum30, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 4, _sum40, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 5, _sum50, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 6, _sum60, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 7, _sum70, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 8, _sum80, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 9, _sum90, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 10, _suma0, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 11, _sumb0, vl); - - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4, _sum01, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 4, _sum11, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 2, _sum21, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 3, _sum31, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 4, _sum41, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 5, _sum51, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 6, _sum61, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 7, _sum71, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 8, _sum81, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 9, _sum91, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 10, _suma1, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 11, _sumb1, vl); - - outptr0 += 48; } - if (out_elempack == 1) - { - transpose8x12_ps(_sum00, _sum01, _sum10, _sum11, _sum20, _sum21, _sum30, _sum31, _sum40, _sum41, _sum50, _sum51, _sum60, _sum61, _sum70, _sum71, _sum80, _sum81, _sum90, _sum91, _suma0, _suma1, _sumb0, _sumb1, vl); - - __riscv_vse32_v_f32m1(outptr0, _sum00, vl); - __riscv_vse32_v_f32m1(outptr0 + 4, _sum01, vl); - __riscv_vse32_v_f32m1(outptr0 + 8, _sum10, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep, _sum11, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep + 4, _sum20, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep + 8, _sum21, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 2, _sum30, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 2 + 4, _sum31, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 2 + 8, _sum40, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 3, _sum41, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 3 + 4, _sum50, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 3 + 8, _sum51, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4, _sum60, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 4, _sum61, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 8, _sum70, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 5, _sum71, vl); - __riscv_vse32_v_f32m1(outptr0 + 
out_hstep * 5 + 4, _sum80, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 5 + 8, _sum81, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 6, _sum90, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 6 + 4, _sum91, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 6 + 8, _suma0, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 7, _suma1, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 7 + 4, _sumb0, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 7 + 8, _sumb1, vl); - - outptr0 += 12; - } - } - else - { - __riscv_vse32_v_f32m1(outptr, _sum00, vl); - __riscv_vse32_v_f32m1(outptr + 4, _sum01, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 2, _sum10, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 3, _sum11, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 4, _sum20, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 5, _sum21, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 6, _sum30, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 7, _sum31, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 8, _sum40, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 9, _sum41, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 10, _sum50, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 11, _sum51, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 12, _sum60, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 13, _sum61, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 14, _sum70, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 15, _sum71, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 16, _sum80, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 17, _sum81, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 18, _sum90, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 19, _sum91, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 20, _suma0, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 21, _suma1, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 22, _sumb0, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 23, _sumb1, vl); - } - - outptr += 96; - } - for (; jj + 7 < max_jj; jj += 8) - { - vfloat32m1_t _sum00; - vfloat32m1_t _sum01; - vfloat32m1_t _sum10; - vfloat32m1_t _sum11; - vfloat32m1_t _sum20; - vfloat32m1_t _sum21; - vfloat32m1_t _sum30; - vfloat32m1_t _sum31; - vfloat32m1_t _sum40; - vfloat32m1_t _sum41; - vfloat32m1_t _sum50; - vfloat32m1_t _sum51; - vfloat32m1_t _sum60; - vfloat32m1_t _sum61; - vfloat32m1_t _sum70; - vfloat32m1_t _sum71; - - if (k == 0) - { - _sum00 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum01 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum10 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum11 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum20 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum21 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum30 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum31 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum40 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum41 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum50 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum51 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum60 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum61 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum70 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum71 = __riscv_vfmv_v_f_f32m1(0.f, vl); - - if (pC) - { - if (broadcast_type_C == 0) - { - _sum00 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum01 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum10 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum11 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum20 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum21 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum30 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum31 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum40 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum41 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum50 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum51 = __riscv_vfmv_v_f_f32m1(pC[0], 
vl); - _sum60 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum61 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum70 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum71 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _sum00 = __riscv_vle32_v_f32m1(pC, vl); - _sum01 = __riscv_vle32_v_f32m1(pC + 4, vl); - _sum10 = _sum00; - _sum11 = _sum01; - _sum20 = _sum00; - _sum21 = _sum01; - _sum30 = _sum00; - _sum31 = _sum01; - _sum40 = _sum00; - _sum41 = _sum01; - _sum50 = _sum00; - _sum51 = _sum01; - _sum60 = _sum00; - _sum61 = _sum01; - _sum70 = _sum00; - _sum71 = _sum01; - } - if (broadcast_type_C == 3) - { - _sum00 = __riscv_vle32_v_f32m1(pC, vl); - _sum01 = __riscv_vle32_v_f32m1(pC + 4 * 1, vl); - _sum10 = __riscv_vle32_v_f32m1(pC + 4 * 2, vl); - _sum11 = __riscv_vle32_v_f32m1(pC + 4 * 3, vl); - _sum20 = __riscv_vle32_v_f32m1(pC + 4 * 4, vl); - _sum21 = __riscv_vle32_v_f32m1(pC + 4 * 5, vl); - _sum30 = __riscv_vle32_v_f32m1(pC + 4 * 6, vl); - _sum31 = __riscv_vle32_v_f32m1(pC + 4 * 7, vl); - _sum40 = __riscv_vle32_v_f32m1(pC + 4 * 8, vl); - _sum41 = __riscv_vle32_v_f32m1(pC + 4 * 9, vl); - _sum50 = __riscv_vle32_v_f32m1(pC + 4 * 10, vl); - _sum51 = __riscv_vle32_v_f32m1(pC + 4 * 11, vl); - _sum60 = __riscv_vle32_v_f32m1(pC + 4 * 12, vl); - _sum61 = __riscv_vle32_v_f32m1(pC + 4 * 13, vl); - _sum70 = __riscv_vle32_v_f32m1(pC + 4 * 14, vl); - _sum71 = __riscv_vle32_v_f32m1(pC + 4 * 15, vl); - pC += 64; - } - if (broadcast_type_C == 4) - { - _sum00 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum10 = __riscv_vfmv_v_f_f32m1(pC[1], vl); - _sum20 = __riscv_vfmv_v_f_f32m1(pC[2], vl); - _sum30 = __riscv_vfmv_v_f_f32m1(pC[3], vl); - _sum40 = __riscv_vfmv_v_f_f32m1(pC[4], vl); - _sum50 = __riscv_vfmv_v_f_f32m1(pC[5], vl); - _sum60 = __riscv_vfmv_v_f_f32m1(pC[6], vl); - _sum70 = __riscv_vfmv_v_f_f32m1(pC[7], vl); - _sum01 = _sum00; - _sum11 = _sum10; - _sum21 = _sum20; - _sum31 = _sum30; - _sum41 = _sum40; - _sum51 = _sum50; - _sum61 = _sum60; - _sum71 = _sum70; - pC += 8; - } - } - } - else - { - _sum00 = __riscv_vle32_v_f32m1(outptr, vl); - _sum01 = __riscv_vle32_v_f32m1(outptr + 4 * 1, vl); - _sum10 = __riscv_vle32_v_f32m1(outptr + 4 * 2, vl); - _sum11 = __riscv_vle32_v_f32m1(outptr + 4 * 3, vl); - _sum20 = __riscv_vle32_v_f32m1(outptr + 4 * 4, vl); - _sum21 = __riscv_vle32_v_f32m1(outptr + 4 * 5, vl); - _sum30 = __riscv_vle32_v_f32m1(outptr + 4 * 6, vl); - _sum31 = __riscv_vle32_v_f32m1(outptr + 4 * 7, vl); - _sum40 = __riscv_vle32_v_f32m1(outptr + 4 * 8, vl); - _sum41 = __riscv_vle32_v_f32m1(outptr + 4 * 9, vl); - _sum50 = __riscv_vle32_v_f32m1(outptr + 4 * 10, vl); - _sum51 = __riscv_vle32_v_f32m1(outptr + 4 * 11, vl); - _sum60 = __riscv_vle32_v_f32m1(outptr + 4 * 12, vl); - _sum61 = __riscv_vle32_v_f32m1(outptr + 4 * 13, vl); - _sum70 = __riscv_vle32_v_f32m1(outptr + 4 * 14, vl); - _sum71 = __riscv_vle32_v_f32m1(outptr + 4 * 15, vl); - } - - const float* pA = pAT; - int kk = 0; - for (; kk < max_kk; kk += 1) - { - vfloat32m1_t _pA0 = __riscv_vle32_v_f32m1(pA, vl); - vfloat32m1_t _pA1 = __riscv_vle32_v_f32m1(pA + 4, vl); - - _sum00 = __riscv_vfmadd_vf_f32m1(_pA0, pB[0], _sum00, vl); - _sum01 = __riscv_vfmadd_vf_f32m1(_pA1, pB[0], _sum01, vl); - _sum10 = __riscv_vfmadd_vf_f32m1(_pA0, pB[1], _sum10, vl); - _sum11 = __riscv_vfmadd_vf_f32m1(_pA1, pB[1], _sum11, vl); - _sum20 = __riscv_vfmadd_vf_f32m1(_pA0, pB[2], _sum20, vl); - _sum21 = __riscv_vfmadd_vf_f32m1(_pA1, pB[2], _sum21, vl); - _sum30 = __riscv_vfmadd_vf_f32m1(_pA0, pB[3], _sum30, vl); - _sum31 = 
__riscv_vfmadd_vf_f32m1(_pA1, pB[3], _sum31, vl); - _sum40 = __riscv_vfmadd_vf_f32m1(_pA0, pB[4], _sum40, vl); - _sum41 = __riscv_vfmadd_vf_f32m1(_pA1, pB[4], _sum41, vl); - _sum50 = __riscv_vfmadd_vf_f32m1(_pA0, pB[5], _sum50, vl); - _sum51 = __riscv_vfmadd_vf_f32m1(_pA1, pB[5], _sum51, vl); - _sum60 = __riscv_vfmadd_vf_f32m1(_pA0, pB[6], _sum60, vl); - _sum61 = __riscv_vfmadd_vf_f32m1(_pA1, pB[6], _sum61, vl); - _sum70 = __riscv_vfmadd_vf_f32m1(_pA0, pB[7], _sum70, vl); - _sum71 = __riscv_vfmadd_vf_f32m1(_pA1, pB[7], _sum71, vl); - - pA += 8; - pB += 8; - } - - if (k_end) - { - if (out_elempack == 4) - { - __riscv_vse32_v_f32m1(outptr0, _sum00, vl); - __riscv_vse32_v_f32m1(outptr0 + 4, _sum10, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 2, _sum20, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 3, _sum30, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 4, _sum40, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 5, _sum50, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 6, _sum60, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 7, _sum70, vl); - - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4, _sum01, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 4, _sum11, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 2, _sum21, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 3, _sum31, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 4, _sum41, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 5, _sum51, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 6, _sum61, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 7, _sum71, vl); - - outptr0 += 32; - } - if (out_elempack == 1) - { - transpose8x8_ps(_sum00, _sum01, _sum10, _sum11, _sum20, _sum21, _sum30, _sum31, _sum40, _sum41, _sum50, _sum51, _sum60, _sum61, _sum70, _sum71, vl); - - __riscv_vse32_v_f32m1(outptr0, _sum00, vl); - __riscv_vse32_v_f32m1(outptr0 + 4, _sum01, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep, _sum10, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep + 4, _sum11, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 2, _sum20, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 2 + 4, _sum21, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 3, _sum30, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 3 + 4, _sum31, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4, _sum40, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 4, _sum41, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 5, _sum50, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 5 + 4, _sum51, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 6, _sum60, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 6 + 4, _sum61, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 7, _sum70, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 7 + 4, _sum71, vl); - - outptr0 += 8; - } - } - else - { - __riscv_vse32_v_f32m1(outptr, _sum00, vl); - __riscv_vse32_v_f32m1(outptr + 4, _sum01, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 2, _sum10, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 3, _sum11, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 4, _sum20, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 5, _sum21, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 6, _sum30, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 7, _sum31, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 8, _sum40, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 9, _sum41, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 10, _sum50, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 11, _sum51, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 12, _sum60, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 13, _sum61, vl); - 
__riscv_vse32_v_f32m1(outptr + 4 * 14, _sum70, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 15, _sum71, vl); - } - - outptr += 64; - } - for (; jj + 3 < max_jj; jj += 4) - { - vfloat32m1_t _sum00; - vfloat32m1_t _sum01; - vfloat32m1_t _sum10; - vfloat32m1_t _sum11; - vfloat32m1_t _sum20; - vfloat32m1_t _sum21; - vfloat32m1_t _sum30; - vfloat32m1_t _sum31; - - if (k == 0) - { - _sum00 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum01 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum10 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum11 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum20 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum21 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum30 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum31 = __riscv_vfmv_v_f_f32m1(0.f, vl); - - if (pC) - { - if (broadcast_type_C == 0) - { - _sum00 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum01 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum10 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum11 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum20 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum21 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum30 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum31 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _sum00 = __riscv_vle32_v_f32m1(pC, vl); - _sum01 = __riscv_vle32_v_f32m1(pC + 4, vl); - _sum10 = _sum00; - _sum11 = _sum01; - _sum20 = _sum00; - _sum21 = _sum01; - _sum30 = _sum00; - _sum31 = _sum01; - } - if (broadcast_type_C == 3) - { - _sum00 = __riscv_vle32_v_f32m1(pC, vl); - _sum01 = __riscv_vle32_v_f32m1(pC + 4 * 1, vl); - _sum10 = __riscv_vle32_v_f32m1(pC + 4 * 2, vl); - _sum11 = __riscv_vle32_v_f32m1(pC + 4 * 3, vl); - _sum20 = __riscv_vle32_v_f32m1(pC + 4 * 4, vl); - _sum21 = __riscv_vle32_v_f32m1(pC + 4 * 5, vl); - _sum30 = __riscv_vle32_v_f32m1(pC + 4 * 6, vl); - _sum31 = __riscv_vle32_v_f32m1(pC + 4 * 7, vl); - pC += 32; - } - if (broadcast_type_C == 4) - { - _sum00 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum10 = __riscv_vfmv_v_f_f32m1(pC[1], vl); - _sum20 = __riscv_vfmv_v_f_f32m1(pC[2], vl); - _sum30 = __riscv_vfmv_v_f_f32m1(pC[3], vl); - _sum01 = _sum00; - _sum11 = _sum10; - _sum21 = _sum20; - _sum31 = _sum30; - pC += 4; - } - } - } - else - { - _sum00 = __riscv_vle32_v_f32m1(outptr, vl); - _sum01 = __riscv_vle32_v_f32m1(outptr + 4 * 1, vl); - _sum10 = __riscv_vle32_v_f32m1(outptr + 4 * 2, vl); - _sum11 = __riscv_vle32_v_f32m1(outptr + 4 * 3, vl); - _sum20 = __riscv_vle32_v_f32m1(outptr + 4 * 4, vl); - _sum21 = __riscv_vle32_v_f32m1(outptr + 4 * 5, vl); - _sum30 = __riscv_vle32_v_f32m1(outptr + 4 * 6, vl); - _sum31 = __riscv_vle32_v_f32m1(outptr + 4 * 7, vl); - } - - const float* pA = pAT; - int kk = 0; - for (; kk < max_kk; kk += 1) - { - vfloat32m1_t _pA0 = __riscv_vle32_v_f32m1(pA, vl); - vfloat32m1_t _pA1 = __riscv_vle32_v_f32m1(pA + 4, vl); - - _sum00 = __riscv_vfmadd_vf_f32m1(_pA0, pB[0], _sum00, vl); - _sum01 = __riscv_vfmadd_vf_f32m1(_pA1, pB[0], _sum01, vl); - _sum10 = __riscv_vfmadd_vf_f32m1(_pA0, pB[1], _sum10, vl); - _sum11 = __riscv_vfmadd_vf_f32m1(_pA1, pB[1], _sum11, vl); - _sum20 = __riscv_vfmadd_vf_f32m1(_pA0, pB[2], _sum20, vl); - _sum21 = __riscv_vfmadd_vf_f32m1(_pA1, pB[2], _sum21, vl); - _sum30 = __riscv_vfmadd_vf_f32m1(_pA0, pB[3], _sum30, vl); - _sum31 = __riscv_vfmadd_vf_f32m1(_pA1, pB[3], _sum31, vl); - - pA += 8; - pB += 4; - } - - if (k_end) - { - if (out_elempack == 4) - { - __riscv_vse32_v_f32m1(outptr0, _sum00, vl); - __riscv_vse32_v_f32m1(outptr0 + 4, _sum10, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 2, _sum20, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 3, 
_sum30, vl); - - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4, _sum01, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 4, _sum11, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 2, _sum21, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 3, _sum31, vl); - - outptr0 += 16; - } - if (out_elempack == 1) - { - transpose8x4_ps(_sum00, _sum01, _sum10, _sum11, _sum20, _sum21, _sum30, _sum31, vl); - - __riscv_vse32_v_f32m1(outptr0, _sum00, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 1, _sum01, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 2, _sum10, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 3, _sum11, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4, _sum20, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 5, _sum21, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 6, _sum30, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 7, _sum31, vl); - - outptr0 += 4; - } - } - else - { - __riscv_vse32_v_f32m1(outptr, _sum00, vl); - __riscv_vse32_v_f32m1(outptr + 4, _sum01, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 2, _sum10, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 3, _sum11, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 4, _sum20, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 5, _sum21, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 6, _sum30, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 7, _sum31, vl); - } - - outptr += 32; - } - for (; jj + 1 < max_jj; jj += 2) - { - vfloat32m1_t _sum00; - vfloat32m1_t _sum01; - vfloat32m1_t _sum10; - vfloat32m1_t _sum11; - - if (k == 0) - { - _sum00 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum01 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum10 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum11 = __riscv_vfmv_v_f_f32m1(0.f, vl); - - if (pC) - { - if (broadcast_type_C == 0) - { - _sum00 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum01 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum10 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum11 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _sum00 = __riscv_vle32_v_f32m1(pC, vl); - _sum01 = __riscv_vle32_v_f32m1(pC + 4, vl); - _sum10 = _sum00; - _sum11 = _sum01; - } - if (broadcast_type_C == 3) - { - _sum00 = __riscv_vle32_v_f32m1(pC, vl); - _sum01 = __riscv_vle32_v_f32m1(pC + 4 * 1, vl); - _sum10 = __riscv_vle32_v_f32m1(pC + 4 * 2, vl); - _sum11 = __riscv_vle32_v_f32m1(pC + 4 * 3, vl); - pC += 16; - } - if (broadcast_type_C == 4) - { - _sum00 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum10 = __riscv_vfmv_v_f_f32m1(pC[1], vl); - _sum01 = _sum00; - _sum11 = _sum10; - pC += 2; - } - } - } - else - { - _sum00 = __riscv_vle32_v_f32m1(outptr, vl); - _sum01 = __riscv_vle32_v_f32m1(outptr + 4 * 1, vl); - _sum10 = __riscv_vle32_v_f32m1(outptr + 4 * 2, vl); - _sum11 = __riscv_vle32_v_f32m1(outptr + 4 * 3, vl); - } - - const float* pA = pAT; - int kk = 0; - for (; kk < max_kk; kk += 1) - { - vfloat32m1_t _pA0 = __riscv_vle32_v_f32m1(pA, vl); - vfloat32m1_t _pA1 = __riscv_vle32_v_f32m1(pA + 4, vl); - - _sum00 = __riscv_vfmadd_vf_f32m1(_pA0, pB[0], _sum00, vl); - _sum01 = __riscv_vfmadd_vf_f32m1(_pA1, pB[0], _sum01, vl); - _sum10 = __riscv_vfmadd_vf_f32m1(_pA0, pB[1], _sum10, vl); - _sum11 = __riscv_vfmadd_vf_f32m1(_pA1, pB[1], _sum11, vl); - - pA += 8; - pB += 2; - } - - if (k_end) - { - if (out_elempack == 4) - { - __riscv_vse32_v_f32m1(outptr0, _sum00, vl); - __riscv_vse32_v_f32m1(outptr0 + 4, _sum10, vl); - - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4, _sum01, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4 + 4, _sum11, vl); - outptr0 += 8; - } - if (out_elempack 
== 1) - { - float sum0[8]; - float sum1[8]; - __riscv_vse32_v_f32m1(sum0, _sum00, vl); - __riscv_vse32_v_f32m1(sum0 + 4, _sum01, vl); - __riscv_vse32_v_f32m1(sum1, _sum10, vl); - __riscv_vse32_v_f32m1(sum1 + 4, _sum11, vl); - - outptr0[0] = sum0[0]; - outptr0[out_hstep] = sum0[1]; - outptr0[out_hstep * 2] = sum0[2]; - outptr0[out_hstep * 3] = sum0[3]; - outptr0[out_hstep * 4] = sum0[4]; - outptr0[out_hstep * 5] = sum0[5]; - outptr0[out_hstep * 6] = sum0[6]; - outptr0[out_hstep * 7] = sum0[7]; - - outptr0[1] = sum1[0]; - outptr0[out_hstep + 1] = sum1[1]; - outptr0[out_hstep * 2 + 1] = sum1[2]; - outptr0[out_hstep * 3 + 1] = sum1[3]; - outptr0[out_hstep * 4 + 1] = sum1[4]; - outptr0[out_hstep * 5 + 1] = sum1[5]; - outptr0[out_hstep * 6 + 1] = sum1[6]; - outptr0[out_hstep * 7 + 1] = sum1[7]; - outptr0 += 2; - } - } - else - { - __riscv_vse32_v_f32m1(outptr, _sum00, vl); - __riscv_vse32_v_f32m1(outptr + 4, _sum01, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 2, _sum10, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 3, _sum11, vl); - } - - outptr += 16; - } - for (; jj < max_jj; jj += 1) - { - vfloat32m1_t _sum00; - vfloat32m1_t _sum01; - - if (k == 0) - { - _sum00 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum01 = __riscv_vfmv_v_f_f32m1(0.f, vl); - - if (pC) - { - if (broadcast_type_C == 0) - { - _sum00 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum01 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _sum00 = __riscv_vle32_v_f32m1(pC, vl); - _sum01 = __riscv_vle32_v_f32m1(pC + 4, vl); - } - if (broadcast_type_C == 3) - { - _sum00 = __riscv_vle32_v_f32m1(pC, vl); - _sum01 = __riscv_vle32_v_f32m1(pC + 4, vl); - pC += 8; - } - if (broadcast_type_C == 4) - { - _sum00 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum01 = _sum00; - pC += 1; - } - } - } - else - { - _sum00 = __riscv_vle32_v_f32m1(outptr, vl); - _sum01 = __riscv_vle32_v_f32m1(outptr + 4 * 1, vl); - } - - const float* pA = pAT; - int kk = 0; - for (; kk < max_kk; kk += 1) - { - vfloat32m1_t _pA0 = __riscv_vle32_v_f32m1(pA, vl); - vfloat32m1_t _pA1 = __riscv_vle32_v_f32m1(pA + 4, vl); - - vfloat32m1_t _pB = __riscv_vfmv_v_f_f32m1(pB[0], vl); - - _sum00 = __riscv_vfmadd_vv_f32m1(_pA0, _pB, _sum00, vl); - _sum01 = __riscv_vfmadd_vv_f32m1(_pA1, _pB, _sum01, vl); - - pA += 8; - pB += 1; - } - - if (k_end) - { - if (out_elempack == 4) - { - __riscv_vse32_v_f32m1(outptr0, _sum00, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 4, _sum01, vl); - outptr0 += 4; - } - if (out_elempack == 1) - { - float sum0[8]; - __riscv_vse32_v_f32m1(sum0, _sum00, vl); - __riscv_vse32_v_f32m1(sum0 + 4, _sum01, vl); - - outptr0[0] = sum0[0]; - outptr0[out_hstep * 1] = sum0[1]; - outptr0[out_hstep * 2] = sum0[2]; - outptr0[out_hstep * 3] = sum0[3]; - outptr0[out_hstep * 4] = sum0[4]; - outptr0[out_hstep * 5] = sum0[5]; - outptr0[out_hstep * 6] = sum0[6]; - outptr0[out_hstep * 7] = sum0[7]; - outptr0++; - } - } - else - { - __riscv_vse32_v_f32m1(outptr, _sum00, vl); - __riscv_vse32_v_f32m1(outptr + 4, _sum01, vl); - } - - outptr += 8; - } - - pAT += max_kk * 8; - } - for (; ii + 3 < max_ii; ii += 4) - { - float* outptr0 = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack; - - const float* pB = pBT; - - if (pC) - { - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - pC = (const float*)CT_tile + i + ii; - } - if (broadcast_type_C == 4) - { - pC = (const float*)CT_tile + j; - } - } - - int jj = 0; - for (; jj + 11 < max_jj; jj += 12) - { - vfloat32m1_t _sum0; - vfloat32m1_t _sum1; - vfloat32m1_t _sum2; - vfloat32m1_t 
_sum3; - vfloat32m1_t _sum4; - vfloat32m1_t _sum5; - vfloat32m1_t _sum6; - vfloat32m1_t _sum7; - vfloat32m1_t _sum8; - vfloat32m1_t _sum9; - vfloat32m1_t _suma; - vfloat32m1_t _sumb; - - if (k == 0) - { - _sum0 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum1 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum2 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum3 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum4 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum5 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum6 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum7 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum8 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum9 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _suma = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sumb = __riscv_vfmv_v_f_f32m1(0.f, vl); - - if (pC) - { - if (broadcast_type_C == 0) - { - _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum1 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum2 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum3 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum4 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum5 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum6 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum7 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum8 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum9 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _suma = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sumb = __riscv_vfmv_v_f_f32m1(pC[0], vl); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _sum0 = __riscv_vle32_v_f32m1(pC, vl); - _sum1 = _sum0; - _sum2 = _sum0; - _sum3 = _sum0; - _sum4 = _sum0; - _sum5 = _sum0; - _sum6 = _sum0; - _sum7 = _sum0; - _sum8 = _sum0; - _sum9 = _sum0; - _suma = _sum0; - _sumb = _sum0; - } - if (broadcast_type_C == 3) - { - _sum0 = __riscv_vle32_v_f32m1(pC, vl); - _sum1 = __riscv_vle32_v_f32m1(pC + 4, vl); - _sum2 = __riscv_vle32_v_f32m1(pC + 8, vl); - _sum3 = __riscv_vle32_v_f32m1(pC + 12, vl); - _sum4 = __riscv_vle32_v_f32m1(pC + 16, vl); - _sum5 = __riscv_vle32_v_f32m1(pC + 20, vl); - _sum6 = __riscv_vle32_v_f32m1(pC + 24, vl); - _sum7 = __riscv_vle32_v_f32m1(pC + 28, vl); - _sum8 = __riscv_vle32_v_f32m1(pC + 32, vl); - _sum9 = __riscv_vle32_v_f32m1(pC + 36, vl); - _suma = __riscv_vle32_v_f32m1(pC + 40, vl); - _sumb = __riscv_vle32_v_f32m1(pC + 44, vl); - pC += 48; - } - if (broadcast_type_C == 4) - { - _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum1 = __riscv_vfmv_v_f_f32m1(pC[1], vl); - _sum2 = __riscv_vfmv_v_f_f32m1(pC[2], vl); - _sum3 = __riscv_vfmv_v_f_f32m1(pC[3], vl); - _sum4 = __riscv_vfmv_v_f_f32m1(pC[4], vl); - _sum5 = __riscv_vfmv_v_f_f32m1(pC[5], vl); - _sum6 = __riscv_vfmv_v_f_f32m1(pC[6], vl); - _sum7 = __riscv_vfmv_v_f_f32m1(pC[7], vl); - _sum8 = __riscv_vfmv_v_f_f32m1(pC[8], vl); - _sum9 = __riscv_vfmv_v_f_f32m1(pC[9], vl); - _suma = __riscv_vfmv_v_f_f32m1(pC[10], vl); - _sumb = __riscv_vfmv_v_f_f32m1(pC[11], vl); - pC += 12; - } - } - } - else - { - _sum0 = __riscv_vle32_v_f32m1(outptr, vl); - _sum1 = __riscv_vle32_v_f32m1(outptr + 4 * 1, vl); - _sum2 = __riscv_vle32_v_f32m1(outptr + 4 * 2, vl); - _sum3 = __riscv_vle32_v_f32m1(outptr + 4 * 3, vl); - _sum4 = __riscv_vle32_v_f32m1(outptr + 4 * 4, vl); - _sum5 = __riscv_vle32_v_f32m1(outptr + 4 * 5, vl); - _sum6 = __riscv_vle32_v_f32m1(outptr + 4 * 6, vl); - _sum7 = __riscv_vle32_v_f32m1(outptr + 4 * 7, vl); - _sum8 = __riscv_vle32_v_f32m1(outptr + 4 * 8, vl); - _sum9 = __riscv_vle32_v_f32m1(outptr + 4 * 9, vl); - _suma = __riscv_vle32_v_f32m1(outptr + 4 * 10, vl); - _sumb = __riscv_vle32_v_f32m1(outptr + 4 * 11, vl); - } - - const float* pA = pAT; - int kk = 0; - for (; kk < max_kk; kk += 1) - { - vfloat32m1_t _pA = 
__riscv_vle32_v_f32m1(pA, vl); - - _sum0 = __riscv_vfmadd_vf_f32m1(_pA, pB[0], _sum0, vl); - _sum1 = __riscv_vfmadd_vf_f32m1(_pA, pB[1], _sum1, vl); - _sum2 = __riscv_vfmadd_vf_f32m1(_pA, pB[2], _sum2, vl); - _sum3 = __riscv_vfmadd_vf_f32m1(_pA, pB[3], _sum3, vl); - _sum4 = __riscv_vfmadd_vf_f32m1(_pA, pB[4], _sum4, vl); - _sum5 = __riscv_vfmadd_vf_f32m1(_pA, pB[5], _sum5, vl); - _sum6 = __riscv_vfmadd_vf_f32m1(_pA, pB[6], _sum6, vl); - _sum7 = __riscv_vfmadd_vf_f32m1(_pA, pB[7], _sum7, vl); - _sum8 = __riscv_vfmadd_vf_f32m1(_pA, pB[8], _sum8, vl); - _sum9 = __riscv_vfmadd_vf_f32m1(_pA, pB[9], _sum9, vl); - _suma = __riscv_vfmadd_vf_f32m1(_pA, pB[10], _suma, vl); - _sumb = __riscv_vfmadd_vf_f32m1(_pA, pB[11], _sumb, vl); - - pA += 4; - pB += 12; - } - - if (k_end) - { - if (out_elempack == 4) - { - __riscv_vse32_v_f32m1(outptr0, _sum0, vl); - __riscv_vse32_v_f32m1(outptr0 + 4, _sum1, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 2, _sum2, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 3, _sum3, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 4, _sum4, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 5, _sum5, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 6, _sum6, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 7, _sum7, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 8, _sum8, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 9, _sum9, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 10, _suma, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 11, _sumb, vl); - outptr0 += 48; - } - if (out_elempack == 1) - { - transpose4x12_ps(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7, _sum8, _sum9, _suma, _sumb, vl); - - __riscv_vse32_v_f32m1(outptr0, _sum0, vl); - __riscv_vse32_v_f32m1(outptr0 + 4, _sum1, vl); - __riscv_vse32_v_f32m1(outptr0 + 8, _sum2, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep, _sum3, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep + 4, _sum4, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep + 8, _sum5, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 2, _sum6, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 2 + 4, _sum7, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 2 + 8, _sum8, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 3, _sum9, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 3 + 4, _suma, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 3 + 8, _sumb, vl); - outptr0 += 12; - } - } - else - { - __riscv_vse32_v_f32m1(outptr, _sum0, vl); - __riscv_vse32_v_f32m1(outptr + 4, _sum1, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 2, _sum2, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 3, _sum3, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 4, _sum4, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 5, _sum5, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 6, _sum6, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 7, _sum7, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 8, _sum8, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 9, _sum9, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 10, _suma, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 11, _sumb, vl); - } - - outptr += 48; - } - for (; jj + 7 < max_jj; jj += 8) - { - vfloat32m1_t _sum0; - vfloat32m1_t _sum1; - vfloat32m1_t _sum2; - vfloat32m1_t _sum3; - vfloat32m1_t _sum4; - vfloat32m1_t _sum5; - vfloat32m1_t _sum6; - vfloat32m1_t _sum7; - - if (k == 0) - { - _sum0 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum1 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum2 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum3 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum4 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum5 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum6 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum7 = __riscv_vfmv_v_f_f32m1(0.f, vl); - 
- if (pC) - { - if (broadcast_type_C == 0) - { - _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum1 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum2 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum3 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum4 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum5 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum6 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum7 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _sum0 = __riscv_vle32_v_f32m1(pC, vl); - _sum1 = _sum0; - _sum2 = _sum0; - _sum3 = _sum0; - _sum4 = _sum0; - _sum5 = _sum0; - _sum6 = _sum0; - _sum7 = _sum0; - } - if (broadcast_type_C == 3) - { - _sum0 = __riscv_vle32_v_f32m1(pC, vl); - _sum1 = __riscv_vle32_v_f32m1(pC + 4, vl); - _sum2 = __riscv_vle32_v_f32m1(pC + 8, vl); - _sum3 = __riscv_vle32_v_f32m1(pC + 12, vl); - _sum4 = __riscv_vle32_v_f32m1(pC + 16, vl); - _sum5 = __riscv_vle32_v_f32m1(pC + 20, vl); - _sum6 = __riscv_vle32_v_f32m1(pC + 24, vl); - _sum7 = __riscv_vle32_v_f32m1(pC + 28, vl); - pC += 32; - } - if (broadcast_type_C == 4) - { - _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum1 = __riscv_vfmv_v_f_f32m1(pC[1], vl); - _sum2 = __riscv_vfmv_v_f_f32m1(pC[2], vl); - _sum3 = __riscv_vfmv_v_f_f32m1(pC[3], vl); - _sum4 = __riscv_vfmv_v_f_f32m1(pC[4], vl); - _sum5 = __riscv_vfmv_v_f_f32m1(pC[5], vl); - _sum6 = __riscv_vfmv_v_f_f32m1(pC[6], vl); - _sum7 = __riscv_vfmv_v_f_f32m1(pC[7], vl); - pC += 8; - } - } - } - else - { - _sum0 = __riscv_vle32_v_f32m1(outptr, vl); - _sum1 = __riscv_vle32_v_f32m1(outptr + 4 * 1, vl); - _sum2 = __riscv_vle32_v_f32m1(outptr + 4 * 2, vl); - _sum3 = __riscv_vle32_v_f32m1(outptr + 4 * 3, vl); - _sum4 = __riscv_vle32_v_f32m1(outptr + 4 * 4, vl); - _sum5 = __riscv_vle32_v_f32m1(outptr + 4 * 5, vl); - _sum6 = __riscv_vle32_v_f32m1(outptr + 4 * 6, vl); - _sum7 = __riscv_vle32_v_f32m1(outptr + 4 * 7, vl); - } - - const float* pA = pAT; - int kk = 0; - for (; kk < max_kk; kk += 1) - { - vfloat32m1_t _pA = __riscv_vle32_v_f32m1(pA, vl); - - _sum0 = __riscv_vfmadd_vf_f32m1(_pA, pB[0], _sum0, vl); - _sum1 = __riscv_vfmadd_vf_f32m1(_pA, pB[1], _sum1, vl); - _sum2 = __riscv_vfmadd_vf_f32m1(_pA, pB[2], _sum2, vl); - _sum3 = __riscv_vfmadd_vf_f32m1(_pA, pB[3], _sum3, vl); - _sum4 = __riscv_vfmadd_vf_f32m1(_pA, pB[4], _sum4, vl); - _sum5 = __riscv_vfmadd_vf_f32m1(_pA, pB[5], _sum5, vl); - _sum6 = __riscv_vfmadd_vf_f32m1(_pA, pB[6], _sum6, vl); - _sum7 = __riscv_vfmadd_vf_f32m1(_pA, pB[7], _sum7, vl); - - pA += 4; - pB += 8; - } - - if (k_end) - { - if (out_elempack == 4) - { - __riscv_vse32_v_f32m1(outptr0, _sum0, vl); - __riscv_vse32_v_f32m1(outptr0 + 4, _sum1, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 2, _sum2, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 3, _sum3, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 4, _sum4, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 5, _sum5, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 6, _sum6, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 7, _sum7, vl); - outptr0 += 32; - } - if (out_elempack == 1) - { - transpose4x8_ps(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7, vl); - - __riscv_vse32_v_f32m1(outptr0, _sum0, vl); - __riscv_vse32_v_f32m1(outptr0 + 4, _sum1, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep, _sum2, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep + 4, _sum3, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 2, _sum4, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 2 + 4, _sum5, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 3, _sum6, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 3 + 4, 
_sum7, vl); - outptr0 += 8; - } - } - else - { - __riscv_vse32_v_f32m1(outptr, _sum0, vl); - __riscv_vse32_v_f32m1(outptr + 4, _sum1, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 2, _sum2, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 3, _sum3, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 4, _sum4, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 5, _sum5, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 6, _sum6, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 7, _sum7, vl); - } - - outptr += 32; - } - for (; jj + 3 < max_jj; jj += 4) - { - vfloat32m1_t _sum0; - vfloat32m1_t _sum1; - vfloat32m1_t _sum2; - vfloat32m1_t _sum3; - - if (k == 0) - { - _sum0 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum1 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum2 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum3 = __riscv_vfmv_v_f_f32m1(0.f, vl); - - if (pC) - { - if (broadcast_type_C == 0) - { - _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum1 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum2 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum3 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _sum0 = __riscv_vle32_v_f32m1(pC, vl); - _sum1 = _sum0; - _sum2 = _sum0; - _sum3 = _sum0; - } - if (broadcast_type_C == 3) + else + { + _sum0 = __riscv_vle32_v_f32m1(outptr, vl); + _sum1 = __riscv_vle32_v_f32m1(outptr + packn, vl); + _sum2 = __riscv_vle32_v_f32m1(outptr + packn * 2, vl); + _sum3 = __riscv_vle32_v_f32m1(outptr + packn * 3, vl); + _sum4 = __riscv_vle32_v_f32m1(outptr + packn * 4, vl); + _sum5 = __riscv_vle32_v_f32m1(outptr + packn * 5, vl); + _sum6 = __riscv_vle32_v_f32m1(outptr + packn * 6, vl); + _sum7 = __riscv_vle32_v_f32m1(outptr + packn * 7, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA = __riscv_vle32_v_f32m1(pA, vl); + _sum0 = __riscv_vfmadd_vf_f32m1(_pA, pB[0], _sum0, vl); + _sum1 = __riscv_vfmadd_vf_f32m1(_pA, pB[1], _sum1, vl); + _sum2 = __riscv_vfmadd_vf_f32m1(_pA, pB[2], _sum2, vl); + _sum3 = __riscv_vfmadd_vf_f32m1(_pA, pB[3], _sum3, vl); + _sum4 = __riscv_vfmadd_vf_f32m1(_pA, pB[4], _sum4, vl); + _sum5 = __riscv_vfmadd_vf_f32m1(_pA, pB[5], _sum5, vl); + _sum6 = __riscv_vfmadd_vf_f32m1(_pA, pB[6], _sum6, vl); + _sum7 = __riscv_vfmadd_vf_f32m1(_pA, pB[7], _sum7, vl); + pA += packn; + pB += 8; + } + + if (k_end) + { + if (out_elempack == packn) { - _sum0 = __riscv_vle32_v_f32m1(pC, vl); - _sum1 = __riscv_vle32_v_f32m1(pC + 4, vl); - _sum2 = __riscv_vle32_v_f32m1(pC + 8, vl); - _sum3 = __riscv_vle32_v_f32m1(pC + 12, vl); - pC += 16; + __riscv_vse32_v_f32m1(outptr0, _sum0, vl); + __riscv_vse32_v_f32m1(outptr0 + packn, _sum1, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 2, _sum2, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 3, _sum3, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 4, _sum4, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 5, _sum5, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 6, _sum6, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 7, _sum7, vl); + outptr0 += packn * 8; } - if (broadcast_type_C == 4) + if (out_elempack == 1) { - _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum1 = __riscv_vfmv_v_f_f32m1(pC[1], vl); - _sum2 = __riscv_vfmv_v_f_f32m1(pC[2], vl); - _sum3 = __riscv_vfmv_v_f_f32m1(pC[3], vl); - pC += 4; + vfloat32m1x8_t _sum = __riscv_vcreate_v_f32m1x8(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7); + __riscv_vssseg8e32_v_f32m1x8(outptr0, out_hstep * sizeof(float), _sum, vl); + outptr0 += 8; } } + else + { + __riscv_vse32_v_f32m1(outptr, _sum0, vl); + __riscv_vse32_v_f32m1(outptr + packn, _sum1, vl); + 
__riscv_vse32_v_f32m1(outptr + packn * 2, _sum2, vl); + __riscv_vse32_v_f32m1(outptr + packn * 3, _sum3, vl); + __riscv_vse32_v_f32m1(outptr + packn * 4, _sum4, vl); + __riscv_vse32_v_f32m1(outptr + packn * 5, _sum5, vl); + __riscv_vse32_v_f32m1(outptr + packn * 6, _sum6, vl); + __riscv_vse32_v_f32m1(outptr + packn * 7, _sum7, vl); + } + + outptr += packn * 8; } - else + else if (packn == 4) { - _sum0 = __riscv_vle32_v_f32m1(outptr, vl); - _sum1 = __riscv_vle32_v_f32m1(outptr + 4 * 1, vl); - _sum2 = __riscv_vle32_v_f32m1(outptr + 4 * 2, vl); - _sum3 = __riscv_vle32_v_f32m1(outptr + 4 * 3, vl); - } + vfloat32m1_t _sum0; + vfloat32m1_t _sum1; + vfloat32m1_t _sum2; + vfloat32m1_t _sum3; - const float* pA = pAT; - int kk = 0; - for (; kk < max_kk; kk += 1) - { - vfloat32m1_t _pA = __riscv_vle32_v_f32m1(pA, vl); + if (k == 0) + { + _sum0 = __riscv_vfmv_v_f_f32m1(0.f, vl); + _sum1 = __riscv_vfmv_v_f_f32m1(0.f, vl); + _sum2 = __riscv_vfmv_v_f_f32m1(0.f, vl); + _sum3 = __riscv_vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = __riscv_vle32_v_f32m1(pC, vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = __riscv_vle32_v_f32m1(pC, vl); + _sum1 = __riscv_vle32_v_f32m1(pC + packn, vl); + _sum2 = __riscv_vle32_v_f32m1(pC + packn * 2, vl); + _sum3 = __riscv_vle32_v_f32m1(pC + packn * 3, vl); + pC += packn * 4; + } + if (broadcast_type_C == 4) + { + _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); + _sum1 = __riscv_vfmv_v_f_f32m1(pC[1], vl); + _sum2 = __riscv_vfmv_v_f_f32m1(pC[2], vl); + _sum3 = __riscv_vfmv_v_f_f32m1(pC[3], vl); + pC += 4; + } + } + } + else + { + _sum0 = __riscv_vle32_v_f32m1(outptr, vl); + _sum1 = __riscv_vle32_v_f32m1(outptr + packn, vl); + _sum2 = __riscv_vle32_v_f32m1(outptr + packn * 2, vl); + _sum3 = __riscv_vle32_v_f32m1(outptr + packn * 3, vl); + } - _sum0 = __riscv_vfmadd_vf_f32m1(_pA, pB[0], _sum0, vl); - _sum1 = __riscv_vfmadd_vf_f32m1(_pA, pB[1], _sum1, vl); - _sum2 = __riscv_vfmadd_vf_f32m1(_pA, pB[2], _sum2, vl); - _sum3 = __riscv_vfmadd_vf_f32m1(_pA, pB[3], _sum3, vl); - pA += 4; - pB += 4; - } + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA = __riscv_vle32_v_f32m1(pA, vl); + _sum0 = __riscv_vfmadd_vf_f32m1(_pA, pB[0], _sum0, vl); + _sum1 = __riscv_vfmadd_vf_f32m1(_pA, pB[1], _sum1, vl); + _sum2 = __riscv_vfmadd_vf_f32m1(_pA, pB[2], _sum2, vl); + _sum3 = __riscv_vfmadd_vf_f32m1(_pA, pB[3], _sum3, vl); + pA += packn; + pB += 4; + } - if (k_end) - { - if (out_elempack == 4) + if (k_end) { - __riscv_vse32_v_f32m1(outptr0, _sum0, vl); - __riscv_vse32_v_f32m1(outptr0 + 4, _sum1, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 2, _sum2, vl); - __riscv_vse32_v_f32m1(outptr0 + 4 * 3, _sum3, vl); - outptr0 += 16; + if (out_elempack == packn) + { + __riscv_vse32_v_f32m1(outptr0, _sum0, vl); + __riscv_vse32_v_f32m1(outptr0 + packn, _sum1, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 2, _sum2, vl); + __riscv_vse32_v_f32m1(outptr0 + packn * 3, _sum3, vl); + outptr0 += packn * 4; + } + if (out_elempack == 1) + { + vfloat32m1x4_t _sum = __riscv_vcreate_v_f32m1x4(_sum0, _sum1, _sum2, _sum3); + __riscv_vssseg4e32_v_f32m1x4(outptr0, out_hstep * sizeof(float), _sum, vl); + outptr0 += 4; + } } - if (out_elempack == 1) + else { - transpose4x4_ps(_sum0, _sum1, _sum2, _sum3, vl); - - __riscv_vse32_v_f32m1(outptr0, _sum0, 
vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 1, _sum1, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 2, _sum2, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep * 3, _sum3, vl); - outptr0 += 4; + __riscv_vse32_v_f32m1(outptr, _sum0, vl); + __riscv_vse32_v_f32m1(outptr + packn, _sum1, vl); + __riscv_vse32_v_f32m1(outptr + packn * 2, _sum2, vl); + __riscv_vse32_v_f32m1(outptr + packn * 3, _sum3, vl); } + + outptr += packn * 4; } else { - __riscv_vse32_v_f32m1(outptr, _sum0, vl); - __riscv_vse32_v_f32m1(outptr + 4, _sum1, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 2, _sum2, vl); - __riscv_vse32_v_f32m1(outptr + 4 * 3, _sum3, vl); + NCNN_LOGE("unsupported vector length"); } - - outptr += 16; } for (; jj + 1 < max_jj; jj += 2) { @@ -2569,8 +857,8 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons if (broadcast_type_C == 3) { _sum0 = __riscv_vle32_v_f32m1(pC, vl); - _sum1 = __riscv_vle32_v_f32m1(pC + 4, vl); - pC += 8; + _sum1 = __riscv_vle32_v_f32m1(pC + packn, vl); + pC += packn * 2; } if (broadcast_type_C == 4) { @@ -2583,7 +871,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons else { _sum0 = __riscv_vle32_v_f32m1(outptr, vl); - _sum1 = __riscv_vle32_v_f32m1(outptr + 4, vl); + _sum1 = __riscv_vle32_v_f32m1(outptr + packn, vl); } const float* pA = pAT; @@ -2595,43 +883,32 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons _sum0 = __riscv_vfmadd_vf_f32m1(_pA, pB[0], _sum0, vl); _sum1 = __riscv_vfmadd_vf_f32m1(_pA, pB[1], _sum1, vl); - pA += 4; + pA += packn; pB += 2; } if (k_end) { - if (out_elempack == 4) + if (out_elempack == packn) { __riscv_vse32_v_f32m1(outptr0, _sum0, vl); - __riscv_vse32_v_f32m1(outptr0 + 4, _sum1, vl); - outptr0 += 8; + __riscv_vse32_v_f32m1(outptr0 + packn, _sum1, vl); + outptr0 += packn * 2; } if (out_elempack == 1) { - float sum0[4]; - float sum1[4]; - __riscv_vse32_v_f32m1(sum0, _sum0, vl); - __riscv_vse32_v_f32m1(sum1, _sum1, vl); - - outptr0[0] = sum0[0]; - outptr0[out_hstep] = sum0[1]; - outptr0[out_hstep * 2] = sum0[2]; - outptr0[out_hstep * 3] = sum0[3]; - outptr0[1] = sum1[0]; - outptr0[out_hstep + 1] = sum1[1]; - outptr0[out_hstep * 2 + 1] = sum1[2]; - outptr0[out_hstep * 3 + 1] = sum1[3]; + vfloat32m1x2_t _sum = __riscv_vcreate_v_f32m1x2(_sum0, _sum1); + __riscv_vssseg2e32_v_f32m1x2(outptr0, out_hstep * sizeof(float), _sum, vl); outptr0 += 2; } } else { __riscv_vse32_v_f32m1(outptr, _sum0, vl); - __riscv_vse32_v_f32m1(outptr + 4, _sum1, vl); + __riscv_vse32_v_f32m1(outptr + packn, _sum1, vl); } - outptr += 8; + outptr += packn * 2; } for (; jj < max_jj; jj += 1) { @@ -2654,7 +931,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons if (broadcast_type_C == 3) { _sum0 = __riscv_vle32_v_f32m1(pC, vl); - pC += 4; + pC += packn; } if (broadcast_type_C == 4) { @@ -2677,26 +954,20 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons _sum0 = __riscv_vfmadd_vv_f32m1(_pA, _pB, _sum0, vl); - pA += 4; + pA += packn; pB += 1; } if (k_end) { - if (out_elempack == 4) + if (out_elempack == packn) { __riscv_vse32_v_f32m1(outptr0, _sum0, vl); - outptr0 += 4; + outptr0 += packn; } if (out_elempack == 1) { - float sum0[4]; - __riscv_vse32_v_f32m1(sum0, _sum0, vl); - - outptr0[0] = sum0[0]; - outptr0[out_hstep] = sum0[1]; - outptr0[out_hstep * 2] = sum0[2]; - outptr0[out_hstep * 3] = sum0[3]; + __riscv_vsse32_v_f32m1(outptr0, out_hstep * sizeof(float), _sum0, vl); outptr0++; } } @@ -2705,10 +976,10 @@ 
static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons __riscv_vse32_v_f32m1(outptr, _sum0, vl); } - outptr += 4; + outptr += packn; } - pAT += max_kk * 4; + pAT += max_kk * packn; } #endif // __riscv_vector for (; ii + 1 < max_ii; ii += 2) @@ -2731,218 +1002,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons int jj = 0; #if __riscv_vector - for (; jj + 11 < max_jj; jj += 12) - { - vfloat32m1_t _sum00; - vfloat32m1_t _sum01; - vfloat32m1_t _sum02; - vfloat32m1_t _sum10; - vfloat32m1_t _sum11; - vfloat32m1_t _sum12; - - if (k == 0) - { - _sum00 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum01 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum02 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum10 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum11 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum12 = __riscv_vfmv_v_f_f32m1(0.f, vl); - - if (pC) - { - if (broadcast_type_C == 0) - { - _sum00 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum01 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum02 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum10 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum11 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum12 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _sum00 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum01 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum02 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum10 = __riscv_vfmv_v_f_f32m1(pC[1], vl); - _sum11 = __riscv_vfmv_v_f_f32m1(pC[1], vl); - _sum12 = __riscv_vfmv_v_f_f32m1(pC[1], vl); - } - if (broadcast_type_C == 3) - { - vfloat32m1x2_t _s0 = __riscv_vlseg2e32_v_f32m1x2(pC, vl); - vfloat32m1x2_t _s1 = __riscv_vlseg2e32_v_f32m1x2(pC + 8, vl); - vfloat32m1x2_t _s2 = __riscv_vlseg2e32_v_f32m1x2(pC + 16, vl); - _sum00 = __riscv_vget_v_f32m1x2_f32m1(_s0, 0); - _sum10 = __riscv_vget_v_f32m1x2_f32m1(_s0, 1); - _sum01 = __riscv_vget_v_f32m1x2_f32m1(_s1, 0); - _sum11 = __riscv_vget_v_f32m1x2_f32m1(_s1, 1); - _sum02 = __riscv_vget_v_f32m1x2_f32m1(_s2, 0); - _sum12 = __riscv_vget_v_f32m1x2_f32m1(_s2, 1); - pC += 24; - } - if (broadcast_type_C == 4) - { - _sum00 = __riscv_vle32_v_f32m1(pC, vl); - _sum01 = __riscv_vle32_v_f32m1(pC + 4, vl); - _sum02 = __riscv_vle32_v_f32m1(pC + 8, vl); - _sum10 = _sum00; - _sum11 = _sum01; - _sum12 = _sum02; - pC += 12; - } - } - } - else - { - vfloat32m1x2_t _s0 = __riscv_vlseg2e32_v_f32m1x2(outptr, vl); - vfloat32m1x2_t _s1 = __riscv_vlseg2e32_v_f32m1x2(outptr + 8, vl); - vfloat32m1x2_t _s2 = __riscv_vlseg2e32_v_f32m1x2(outptr + 16, vl); - _sum00 = __riscv_vget_v_f32m1x2_f32m1(_s0, 0); - _sum10 = __riscv_vget_v_f32m1x2_f32m1(_s0, 1); - _sum01 = __riscv_vget_v_f32m1x2_f32m1(_s1, 0); - _sum11 = __riscv_vget_v_f32m1x2_f32m1(_s1, 1); - _sum02 = __riscv_vget_v_f32m1x2_f32m1(_s2, 0); - _sum12 = __riscv_vget_v_f32m1x2_f32m1(_s2, 1); - } - - const float* pA = pAT; - int kk = 0; - for (; kk < max_kk; kk += 1) - { - vfloat32m1_t _pB0 = __riscv_vle32_v_f32m1(pB, vl); - vfloat32m1_t _pB1 = __riscv_vle32_v_f32m1(pB + 4, vl); - vfloat32m1_t _pB2 = __riscv_vle32_v_f32m1(pB + 8, vl); - - _sum00 = __riscv_vfmadd_vf_f32m1(_pB0, pA[0], _sum00, vl); - _sum01 = __riscv_vfmadd_vf_f32m1(_pB1, pA[0], _sum01, vl); - _sum02 = __riscv_vfmadd_vf_f32m1(_pB2, pA[0], _sum02, vl); - _sum10 = __riscv_vfmadd_vf_f32m1(_pB0, pA[1], _sum10, vl); - _sum11 = __riscv_vfmadd_vf_f32m1(_pB1, pA[1], _sum11, vl); - _sum12 = __riscv_vfmadd_vf_f32m1(_pB2, pA[1], _sum12, vl); - - pA += 2; - pB += 12; - } - - if (k_end) - { - // if (out_elempack == 1) - { - __riscv_vse32_v_f32m1(outptr0, _sum00, 
vl); - __riscv_vse32_v_f32m1(outptr0 + 4, _sum01, vl); - __riscv_vse32_v_f32m1(outptr0 + 8, _sum02, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep, _sum10, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep + 4, _sum11, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep + 8, _sum12, vl); - outptr0 += 12; - } - } - else - { - __riscv_vsseg2e32_v_f32m1x2(outptr, __riscv_vcreate_v_f32m1x2(_sum00, _sum10), vl); - __riscv_vsseg2e32_v_f32m1x2(outptr + 8, __riscv_vcreate_v_f32m1x2(_sum01, _sum11), vl); - __riscv_vsseg2e32_v_f32m1x2(outptr + 16, __riscv_vcreate_v_f32m1x2(_sum02, _sum12), vl); - } - - outptr += 24; - } - for (; jj + 7 < max_jj; jj += 8) - { - vfloat32m1_t _sum00; - vfloat32m1_t _sum01; - vfloat32m1_t _sum10; - vfloat32m1_t _sum11; - - if (k == 0) - { - _sum00 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum01 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum10 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum11 = __riscv_vfmv_v_f_f32m1(0.f, vl); - - if (pC) - { - if (broadcast_type_C == 0) - { - _sum00 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum01 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum10 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum11 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _sum00 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum01 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum10 = __riscv_vfmv_v_f_f32m1(pC[1], vl); - _sum11 = __riscv_vfmv_v_f_f32m1(pC[1], vl); - } - if (broadcast_type_C == 3) - { - vfloat32m1x2_t _s0 = __riscv_vlseg2e32_v_f32m1x2(pC, vl); - vfloat32m1x2_t _s1 = __riscv_vlseg2e32_v_f32m1x2(pC + 8, vl); - _sum00 = __riscv_vget_v_f32m1x2_f32m1(_s0, 0); - _sum10 = __riscv_vget_v_f32m1x2_f32m1(_s0, 1); - _sum01 = __riscv_vget_v_f32m1x2_f32m1(_s1, 0); - _sum11 = __riscv_vget_v_f32m1x2_f32m1(_s1, 1); - pC += 16; - } - if (broadcast_type_C == 4) - { - _sum00 = __riscv_vle32_v_f32m1(pC, vl); - _sum01 = __riscv_vle32_v_f32m1(pC + 4, vl); - _sum10 = _sum00; - _sum11 = _sum01; - pC += 8; - } - } - } - else - { - vfloat32m1x2_t _s0 = __riscv_vlseg2e32_v_f32m1x2(outptr, vl); - vfloat32m1x2_t _s1 = __riscv_vlseg2e32_v_f32m1x2(outptr + 8, vl); - _sum00 = __riscv_vget_v_f32m1x2_f32m1(_s0, 0); - _sum10 = __riscv_vget_v_f32m1x2_f32m1(_s0, 1); - _sum01 = __riscv_vget_v_f32m1x2_f32m1(_s1, 0); - _sum11 = __riscv_vget_v_f32m1x2_f32m1(_s1, 1); - } - - const float* pA = pAT; - int kk = 0; - for (; kk < max_kk; kk += 1) - { - vfloat32m1_t _pB0 = __riscv_vle32_v_f32m1(pB, vl); - vfloat32m1_t _pB1 = __riscv_vle32_v_f32m1(pB + 4, vl); - - _sum00 = __riscv_vfmadd_vf_f32m1(_pB0, pA[0], _sum00, vl); - _sum01 = __riscv_vfmadd_vf_f32m1(_pB1, pA[0], _sum01, vl); - _sum10 = __riscv_vfmadd_vf_f32m1(_pB0, pA[1], _sum10, vl); - _sum11 = __riscv_vfmadd_vf_f32m1(_pB1, pA[1], _sum11, vl); - pA += 2; - pB += 8; - } - - if (k_end) - { - // if (out_elempack == 1) - { - __riscv_vse32_v_f32m1(outptr0, _sum00, vl); - __riscv_vse32_v_f32m1(outptr0 + 4, _sum01, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep, _sum10, vl); - __riscv_vse32_v_f32m1(outptr0 + out_hstep + 4, _sum11, vl); - outptr0 += 8; - } - } - else - { - __riscv_vsseg2e32_v_f32m1x2(outptr, __riscv_vcreate_v_f32m1x2(_sum00, _sum10), vl); - __riscv_vsseg2e32_v_f32m1x2(outptr + 8, __riscv_vcreate_v_f32m1x2(_sum01, _sum11), vl); - } - - outptr += 16; - } - for (; jj + 3 < max_jj; jj += 4) + for (; jj + (packn - 1) < max_jj; jj += packn) { vfloat32m1_t _sum0; vfloat32m1_t _sum1; @@ -2969,13 +1029,13 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons vfloat32m1x2_t _s0 = 
__riscv_vlseg2e32_v_f32m1x2(pC, vl); _sum0 = __riscv_vget_v_f32m1x2_f32m1(_s0, 0); _sum1 = __riscv_vget_v_f32m1x2_f32m1(_s0, 1); - pC += 8; + pC += packn * 2; } if (broadcast_type_C == 4) { _sum0 = __riscv_vle32_v_f32m1(pC, vl); _sum1 = _sum0; - pC += 4; + pC += packn; } } } @@ -2996,7 +1056,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons _sum1 = __riscv_vfmadd_vf_f32m1(_pB, pA[1], _sum1, vl); pA += 2; - pB += 4; + pB += packn; } if (k_end) @@ -3005,7 +1065,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons { __riscv_vse32_v_f32m1(outptr0, _sum0, vl); __riscv_vse32_v_f32m1(outptr0 + out_hstep, _sum1, vl); - outptr0 += 4; + outptr0 += packn; } } else @@ -3013,7 +1073,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons __riscv_vsseg2e32_v_f32m1x2(outptr, __riscv_vcreate_v_f32m1x2(_sum0, _sum1), vl); } - outptr += 8; + outptr += packn * 2; } #endif // __riscv_vector for (; jj + 1 < max_jj; jj += 2) @@ -3198,143 +1258,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons int jj = 0; #if __riscv_vector - for (; jj + 11 < max_jj; jj += 12) - { - vfloat32m1_t _sum0; - vfloat32m1_t _sum1; - vfloat32m1_t _sum2; - - if (k == 0) - { - _sum0 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum1 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum2 = __riscv_vfmv_v_f_f32m1(0.f, vl); - - if (pC) - { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum1 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum2 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - _sum0 = __riscv_vle32_v_f32m1(pC, vl); - _sum1 = __riscv_vle32_v_f32m1(pC + 4, vl); - _sum2 = __riscv_vle32_v_f32m1(pC + 8, vl); - pC += 12; - } - } - } - else - { - _sum0 = __riscv_vle32_v_f32m1(outptr, vl); - _sum1 = __riscv_vle32_v_f32m1(outptr + 4, vl); - _sum2 = __riscv_vle32_v_f32m1(outptr + 8, vl); - } - - const float* pA = pAT; - int kk = 0; - for (; kk < max_kk; kk += 1) - { - vfloat32m1_t _pB0 = __riscv_vle32_v_f32m1(pB, vl); - vfloat32m1_t _pB1 = __riscv_vle32_v_f32m1(pB + 4, vl); - vfloat32m1_t _pB2 = __riscv_vle32_v_f32m1(pB + 8, vl); - - vfloat32m1_t _pA0 = __riscv_vfmv_v_f_f32m1(pA[0], vl); - - _sum0 = __riscv_vfmadd_vv_f32m1(_pA0, _pB0, _sum0, vl); - _sum1 = __riscv_vfmadd_vv_f32m1(_pA0, _pB1, _sum1, vl); - _sum2 = __riscv_vfmadd_vv_f32m1(_pA0, _pB2, _sum2, vl); - - pA += 1; - pB += 12; - } - - if (k_end) - { - // if (out_elempack == 1) - { - __riscv_vse32_v_f32m1(outptr0, _sum0, vl); - __riscv_vse32_v_f32m1(outptr0 + 4, _sum1, vl); - __riscv_vse32_v_f32m1(outptr0 + 8, _sum2, vl); - outptr0 += 12; - } - } - else - { - __riscv_vse32_v_f32m1(outptr, _sum0, vl); - __riscv_vse32_v_f32m1(outptr + 4, _sum1, vl); - __riscv_vse32_v_f32m1(outptr + 8, _sum2, vl); - } - - outptr += 12; - } - for (; jj + 7 < max_jj; jj += 8) - { - vfloat32m1_t _sum0; - vfloat32m1_t _sum1; - - if (k == 0) - { - _sum0 = __riscv_vfmv_v_f_f32m1(0.f, vl); - _sum1 = __riscv_vfmv_v_f_f32m1(0.f, vl); - - if (pC) - { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _sum0 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - _sum1 = __riscv_vfmv_v_f_f32m1(pC[0], vl); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - _sum0 = __riscv_vle32_v_f32m1(pC, vl); - _sum1 = __riscv_vle32_v_f32m1(pC + 4, vl); - pC += 8; - } - } - } - else - { - _sum0 = __riscv_vle32_v_f32m1(outptr, vl); - _sum1 = 
__riscv_vle32_v_f32m1(outptr + 4, vl); - } - - const float* pA = pAT; - int kk = 0; - for (; kk < max_kk; kk += 1) - { - vfloat32m1_t _pB0 = __riscv_vle32_v_f32m1(pB, vl); - vfloat32m1_t _pB1 = __riscv_vle32_v_f32m1(pB + 4, vl); - - vfloat32m1_t _pA0 = __riscv_vfmv_v_f_f32m1(pA[0], vl); - _sum0 = __riscv_vfmadd_vv_f32m1(_pA0, _pB0, _sum0, vl); - _sum1 = __riscv_vfmadd_vv_f32m1(_pA0, _pB1, _sum1, vl); - - pA += 1; - pB += 8; - } - - if (k_end) - { - // if (out_elempack == 1) - { - __riscv_vse32_v_f32m1(outptr0, _sum0, vl); - __riscv_vse32_v_f32m1(outptr0 + 4, _sum1, vl); - outptr0 += 8; - } - } - else - { - __riscv_vse32_v_f32m1(outptr, _sum0, vl); - __riscv_vse32_v_f32m1(outptr + 4, _sum1, vl); - } - - outptr += 8; - } - for (; jj + 3 < max_jj; jj += 4) + for (; jj + (packn - 1) < max_jj; jj += packn) { vfloat32m1_t _sum; @@ -3351,7 +1275,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons if (broadcast_type_C == 3 || broadcast_type_C == 4) { _sum = __riscv_vle32_v_f32m1(pC, vl); - pC += 4; + pC += packn; } } } @@ -3370,7 +1294,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons _sum = __riscv_vfmadd_vv_f32m1(_pA, _pB, _sum, vl); pA += 1; - pB += 4; + pB += packn; } if (k_end) @@ -3378,7 +1302,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons // if (out_elempack == 1) { __riscv_vse32_v_f32m1(outptr0, _sum, vl); - outptr0 += 4; + outptr0 += packn; } } else @@ -3386,7 +1310,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons __riscv_vse32_v_f32m1(outptr, _sum, vl); } - outptr += 4; + outptr += packn; } #endif // __riscv_vector for (; jj + 1 < max_jj; jj += 2) @@ -3513,8 +1437,14 @@ static void get_optimal_tile_mnk(int M, int N, int K, int constant_TILE_M, int c int tile_size = (int)sqrtf((float)l2_cache_size / 3 / sizeof(float)); - TILE_M = std::max(8, tile_size / 8 * 8); - TILE_N = std::max(4, tile_size / 4 * 4); +#if __riscv_vector + const int packn = csrr_vlenb() / 4; +#else + const int packn = 4; +#endif + + TILE_M = std::max(packn, tile_size / packn * packn); + TILE_N = std::max(packn, tile_size / packn * packn); TILE_K = std::max(8, tile_size / 8 * 8); if (K > 0) @@ -3525,8 +1455,8 @@ static void get_optimal_tile_mnk(int M, int N, int K, int constant_TILE_M, int c if (nn_K == 1) { tile_size = (int)((float)l2_cache_size / 2 / sizeof(float) / TILE_K); - TILE_M = std::max(8, tile_size / 8 * 8); - TILE_N = std::max(4, tile_size / 4 * 4); + TILE_M = std::max(packn, tile_size / packn * packn); + TILE_N = std::max(packn, tile_size / packn * packn); } } @@ -3535,29 +1465,29 @@ static void get_optimal_tile_mnk(int M, int N, int K, int constant_TILE_M, int c if (M > 0) { int nn_M = (M + TILE_M - 1) / TILE_M; - TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 7) / 8 * 8); + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + (packn - 1)) / packn * packn); } if (N > 0) { int nn_N = (N + TILE_N - 1) / TILE_N; - TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + (packn - 1)) / packn * packn); } if (nT > 1) { - TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 7) / 8 * 8); + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + (packn - 1)) / packn * packn); } // always take constant TILE_M/N/K value when provided if (constant_TILE_M > 0) { - TILE_M = (constant_TILE_M + 7) / 8 * 8; + TILE_M = (constant_TILE_M + (packn - 1)) / packn * packn; } if (constant_TILE_N > 0) { - TILE_N = 
(constant_TILE_N + 3) / 4 * 4; + TILE_N = (constant_TILE_N + (packn - 1)) / packn * packn; } if (constant_TILE_K > 0) @@ -3566,7 +1496,7 @@ static void get_optimal_tile_mnk(int M, int N, int K, int constant_TILE_M, int c } } -static int gemm_riscv(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, size_t vl, const Option& opt) +static int gemm_riscv(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) { const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; @@ -3608,11 +1538,11 @@ static int gemm_riscv(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, i if (transB) { - pack_B_tile(B, BT_tile, j, max_jj, k, max_kk, vl); + pack_B_tile(B, BT_tile, j, max_jj, k, max_kk); } else { - transpose_pack_B_tile(B, BT_tile, j, max_jj, k, max_kk, vl); + transpose_pack_B_tile(B, BT_tile, j, max_jj, k, max_kk); } } @@ -3641,7 +1571,7 @@ static int gemm_riscv(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, i if (broadcast_type_C == 3) { - pack_A_tile(C, topT_tile, i, max_ii, j, max_jj, vl); + pack_A_tile(C, topT_tile, i, max_ii, j, max_jj); } const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; @@ -3660,21 +1590,21 @@ static int gemm_riscv(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, i { if (transA) { - transpose_pack_A_tile(A, AT_tile, i, max_ii, k, max_kk, vl); + transpose_pack_A_tile(A, AT_tile, i, max_ii, k, max_kk); } else { - pack_A_tile(A, AT_tile, i, max_ii, k, max_kk, vl); + pack_A_tile(A, AT_tile, i, max_ii, k, max_kk); } } bool k_end = !output_transpose && k + TILE_K >= K; - gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, k_end, vl); + gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, k_end); } if (output_transpose) { - transpose_unpack_output_tile(topT_tile, top_blob, i, max_ii, j, max_jj, vl); + transpose_unpack_output_tile(topT_tile, top_blob, i, max_ii, j, max_jj); } } } @@ -3682,7 +1612,7 @@ static int gemm_riscv(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, i return 0; } -static int gemm_AT_riscv(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, size_t vl, const Option& opt) +static int gemm_AT_riscv(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) { const int N = transB ? (B.dims == 3 ? 
B.c : B.h) * B.elempack : B.w; @@ -3718,11 +1648,11 @@ static int gemm_AT_riscv(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blo if (transB) { - pack_B_tile(B, BT_tile, j, max_jj, k, max_kk, vl); + pack_B_tile(B, BT_tile, j, max_jj, k, max_kk); } else { - transpose_pack_B_tile(B, BT_tile, j, max_jj, k, max_kk, vl); + transpose_pack_B_tile(B, BT_tile, j, max_jj, k, max_kk); } } @@ -3746,7 +1676,7 @@ static int gemm_AT_riscv(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blo if (broadcast_type_C == 3) { - pack_A_tile(C, topT_tile, i, max_ii, j, max_jj, vl); + pack_A_tile(C, topT_tile, i, max_ii, j, max_jj); } const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; @@ -3761,12 +1691,12 @@ static int gemm_AT_riscv(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blo Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); bool k_end = !output_transpose && k + TILE_K >= K; - gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, k_end, vl); + gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, k_end); } if (output_transpose) { - transpose_unpack_output_tile(topT_tile, top_blob, i, max_ii, j, max_jj, vl); + transpose_unpack_output_tile(topT_tile, top_blob, i, max_ii, j, max_jj); } } } @@ -3774,7 +1704,7 @@ static int gemm_AT_riscv(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blo return 0; } -static int gemm_BT_riscv(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, size_t vl, const Option& opt) +static int gemm_BT_riscv(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) { const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; @@ -3815,7 +1745,7 @@ static int gemm_BT_riscv(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blo if (broadcast_type_C == 3) { - pack_A_tile(C, topT_tile, i, max_ii, j, max_jj, vl); + pack_A_tile(C, topT_tile, i, max_ii, j, max_jj); } const Mat& CT_tile = broadcast_type_C == 3 ? 
topT_tile : C; @@ -3834,22 +1764,22 @@ static int gemm_BT_riscv(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blo { if (transA) { - transpose_pack_A_tile(A, AT_tile, i, max_ii, k, max_kk, vl); + transpose_pack_A_tile(A, AT_tile, i, max_ii, k, max_kk); } else { - pack_A_tile(A, AT_tile, i, max_ii, k, max_kk, vl); + pack_A_tile(A, AT_tile, i, max_ii, k, max_kk); } } bool k_end = !output_transpose && k + TILE_K >= K; - gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, k_end, vl); + gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, k_end); } if (output_transpose) { - transpose_unpack_output_tile(topT_tile, top_blob, i, max_ii, j, max_jj, vl); + transpose_unpack_output_tile(topT_tile, top_blob, i, max_ii, j, max_jj); } } } @@ -3857,7 +1787,7 @@ static int gemm_BT_riscv(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blo return 0; } -static int gemm_AT_BT_riscv(const Mat& AT, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int N, int K, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, size_t vl, const Option& opt) +static int gemm_AT_BT_riscv(const Mat& AT, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int N, int K, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) { // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); @@ -3890,7 +1820,7 @@ static int gemm_AT_BT_riscv(const Mat& AT, const Mat& BT, const Mat& C, Mat& top if (broadcast_type_C == 3) { - pack_A_tile(C, topT_tile, i, max_ii, j, max_jj, vl); + pack_A_tile(C, topT_tile, i, max_ii, j, max_jj); } const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; @@ -3907,12 +1837,12 @@ static int gemm_AT_BT_riscv(const Mat& AT, const Mat& BT, const Mat& C, Mat& top bool k_end = !output_transpose && k + TILE_K >= K; - gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, k_end, vl); + gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, k_end); } if (output_transpose) { - transpose_unpack_output_tile(topT_tile, top_blob, i, max_ii, j, max_jj, vl); + transpose_unpack_output_tile(topT_tile, top_blob, i, max_ii, j, max_jj); } } } @@ -3967,11 +1897,11 @@ int Gemm_riscv::create_pipeline(const Option& opt) if (transA) { - transpose_pack_A_tile(A_data, AT_tile, i, max_ii, k, max_kk, vl); + transpose_pack_A_tile(A_data, AT_tile, i, max_ii, k, max_kk); } else { - pack_A_tile(A_data, AT_tile, i, max_ii, k, max_kk, vl); + pack_A_tile(A_data, AT_tile, i, max_ii, k, max_kk); } } } @@ -4008,11 +1938,11 @@ int Gemm_riscv::create_pipeline(const Option& opt) if (transB) { - pack_B_tile(B_data, BT_tile, j, max_jj, k, max_kk, vl); + pack_B_tile(B_data, BT_tile, j, max_jj, k, max_kk); } else { - transpose_pack_B_tile(B_data, BT_tile, j, max_jj, k, max_kk, vl); + transpose_pack_B_tile(B_data, BT_tile, j, max_jj, k, max_kk); } } } @@ -4026,9 +1956,11 @@ int Gemm_riscv::create_pipeline(const Option& opt) CT_data = C_data; #if __riscv_vector + const int packn = csrr_vlenb() / 4; + if (constant_broadcast_type_C == 3 && opt.use_packing_layout) { - int C_elempack = constantM % 4 == 0 ? 4 : 1; + int C_elempack = constantM % packn == 0 ? 
packn : 1; convert_packing(C_data, CT_data, C_elempack, opt); } #endif // __riscv_vector @@ -4173,12 +2105,16 @@ int Gemm_riscv::forward(const std::vector& bottom_blobs, std::vector& } } +#if __riscv_vector + const int packn = csrr_vlenb() / 4; +#endif + int out_elempack = 1; #if __riscv_vector if (opt.use_packing_layout) { int outh = output_transpose ? N : M; - out_elempack = outh % 4 == 0 ? 4 : 1; + out_elempack = outh % packn == 0 ? packn : 1; } #endif // __riscv_vector if (output_elempack) @@ -4214,23 +2150,23 @@ int Gemm_riscv::forward(const std::vector& bottom_blobs, std::vector& int ret = 0; if (constantA && constantB) { - ret = gemm_AT_BT_riscv(AT_data, BT_data, C, top_blob, broadcast_type_C, constantM, constantN, constantK, output_transpose, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, vl, opt); + ret = gemm_AT_BT_riscv(AT_data, BT_data, C, top_blob, broadcast_type_C, constantM, constantN, constantK, output_transpose, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt); } else if (constantA) { const Mat& B = bottom_blobs[0]; - ret = gemm_AT_riscv(AT_data, B, C, top_blob, broadcast_type_C, constantM, constantK, transB, output_transpose, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, vl, opt); + ret = gemm_AT_riscv(AT_data, B, C, top_blob, broadcast_type_C, constantM, constantK, transB, output_transpose, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt); } else if (constantB) { const Mat& A = bottom_blobs[0]; - ret = gemm_BT_riscv(A, BT_data, C, top_blob, broadcast_type_C, constantN, constantK, transA, output_transpose, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, vl, opt); + ret = gemm_BT_riscv(A, BT_data, C, top_blob, broadcast_type_C, constantN, constantK, transA, output_transpose, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt); } else { const Mat& A = bottom_blobs[0]; const Mat& B = bottom_blobs[1]; - ret = gemm_riscv(A, B, C, top_blob, broadcast_type_C, transA, transB, output_transpose, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, vl, opt); + ret = gemm_riscv(A, B, C, top_blob, broadcast_type_C, transA, transB, output_transpose, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt); } if (ret != 0) return ret; diff --git a/src/layer/riscv/gemm_riscv.h b/src/layer/riscv/gemm_riscv.h index 6bca092fb1f..967a9ee12c9 100644 --- a/src/layer/riscv/gemm_riscv.h +++ b/src/layer/riscv/gemm_riscv.h @@ -30,9 +30,8 @@ class Gemm_riscv : public Gemm virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; - // public: +public: int nT; - size_t vl; Mat AT_data; Mat BT_data; Mat CT_data;
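
The core of this patch is replacing the fixed vl = 4 configuration with a pack size derived from the hardware vector length: packn = csrr_vlenb() / 4 fp32 lanes per LMUL=1 register, with vl pinned to exactly packn via __riscv_vsetvl_e32m1(packn). A minimal standalone sketch of that derivation follows; __riscv_vsetvlmax_e32m1() stands in for ncnn's csrr_vlenb() helper (for fp32 at LMUL=1 both evaluate to VLEN/32), and main()/printf are illustrative only. Assumes a toolchain with the prefixed RVV intrinsics, e.g. -march=rv64gcv.

#include <riscv_vector.h>
#include <cstdio>

int main()
{
    // VLMAX for e32/m1 is VLEN/32, the same value the patch computes
    // as csrr_vlenb() / 4 (vlenb is the vector register width in bytes).
    const int packn = (int)__riscv_vsetvlmax_e32m1();

    // Requesting exactly packn lanes always yields vl == packn, so the
    // kernels never run with a partial vector; this is why the per-layer
    // size_t vl member could be dropped from gemm_riscv.h.
    const size_t vl = __riscv_vsetvl_e32m1(packn);

    printf("packn = %d, vl = %zu\n", packn, vl);
    return 0;
}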
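The accumulation pattern repeated through the deleted and added hunks is a rank-1 update: load one packed column of A, then fold it into each accumulator with one scalar of B via vfmadd.vf. A self-contained sketch of the packn x 4 variant, mirroring the kk loop in gemm_transB_packed_tile; kernel_packn_x4 and the flat pA/pB layout (packn floats of A and 4 floats of B per k step) are illustrative assumptions, not the file's actual entry point.

#include <riscv_vector.h>

// Minimal packn x 4 microkernel in the style of gemm_transB_packed_tile.
// pA: packn rows of A packed k-major (packn floats per k step).
// pB: 4 columns of B (4 floats per k step). out: packn * 4 floats.
static void kernel_packn_x4(const float* pA, const float* pB, float* out, int max_kk)
{
    const int packn = (int)__riscv_vsetvlmax_e32m1();
    const size_t vl = __riscv_vsetvl_e32m1(packn);

    vfloat32m1_t _sum0 = __riscv_vfmv_v_f_f32m1(0.f, vl);
    vfloat32m1_t _sum1 = __riscv_vfmv_v_f_f32m1(0.f, vl);
    vfloat32m1_t _sum2 = __riscv_vfmv_v_f_f32m1(0.f, vl);
    vfloat32m1_t _sum3 = __riscv_vfmv_v_f_f32m1(0.f, vl);

    for (int kk = 0; kk < max_kk; kk++)
    {
        // one packed column of A, broadcast-multiplied by 4 scalars of B
        vfloat32m1_t _pA = __riscv_vle32_v_f32m1(pA, vl);
        _sum0 = __riscv_vfmadd_vf_f32m1(_pA, pB[0], _sum0, vl);
        _sum1 = __riscv_vfmadd_vf_f32m1(_pA, pB[1], _sum1, vl);
        _sum2 = __riscv_vfmadd_vf_f32m1(_pA, pB[2], _sum2, vl);
        _sum3 = __riscv_vfmadd_vf_f32m1(_pA, pB[3], _sum3, vl);
        pA += packn;
        pB += 4;
    }

    // packed (elempack == packn) store, as in the k_end path above
    __riscv_vse32_v_f32m1(out, _sum0, vl);
    __riscv_vse32_v_f32m1(out + packn, _sum1, vl);
    __riscv_vse32_v_f32m1(out + packn * 2, _sum2, vl);
    __riscv_vse32_v_f32m1(out + packn * 3, _sum3, vl);
}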
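The other systematic change sits in the out_elempack == 1 epilogues: the old code transposed in registers (transpose4x4_ps and friends) or spilled to float sum0[4] arrays and scattered scalars across out_hstep rows, while the new code emits a single strided segment store (vssseg2/4/8e32, or vsse32 for a single column). A sketch of the 4-column case; store_rowmajor_4cols is an illustrative name, and outptr0/out_hstep follow the meaning they have in the hunks above.

#include <riscv_vector.h>

// Write 4 packed accumulators (_sum0.._sum3 = 4 output columns over
// packn output rows) into a row-major matrix in one instruction.
// For each lane i, vssseg4e32 stores the 4-element segment
// {_sum0[i], _sum1[i], _sum2[i], _sum3[i]} at outptr0 + i * out_hstep,
// which is exactly the in-register transpose plus four vse32 stores
// that the deleted code performed.
static void store_rowmajor_4cols(float* outptr0, int out_hstep,
                                 vfloat32m1_t _sum0, vfloat32m1_t _sum1,
                                 vfloat32m1_t _sum2, vfloat32m1_t _sum3,
                                 size_t vl)
{
    vfloat32m1x4_t _sum = __riscv_vcreate_v_f32m1x4(_sum0, _sum1, _sum2, _sum3);
    __riscv_vssseg4e32_v_f32m1x4(outptr0, out_hstep * sizeof(float), _sum, vl);
}

Because every transpose-and-scatter epilogue collapses to one such call, the transpose4x4_ps/transpose4x8_ps/transpose4x12_ps and transpose8x*_ps helpers and the scalar sum0[]/sum1[] spill loops can be deleted wholesale, which accounts for most of the 2510 removed lines.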
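get_optimal_tile_mnk now aligns TILE_M and TILE_N to packn rather than the hard-coded 8 and 4. The two integer idioms it leans on, round-down and round-up to a multiple, are easy to misread in-line, so they are isolated here; round_down/round_up are illustrative helper names, not functions in the file.

// Alignment idioms from get_optimal_tile_mnk; packn is assumed > 0.
static inline int round_down(int x, int packn)
{
    return x / packn * packn;               // as in: tile_size / packn * packn
}

static inline int round_up(int x, int packn)
{
    return (x + packn - 1) / packn * packn; // as in: (constant_TILE_M + (packn - 1)) / packn * packn
}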