Skip to content

Commit

Permalink
RVV: Refine riscv gemm fp32 (#5303)
Browse files Browse the repository at this point in the history
* replace storexxx with vsseg2e32_v_f32m1

* refine transpose

---------

Co-authored-by: Xinyu302 <[email protected]>
  • Loading branch information
Xinyu302 and Xinyu302 authored Jan 30, 2024
1 parent 10fd242 commit 7ac4268
Show file tree
Hide file tree
Showing 2 changed files with 237 additions and 318 deletions.
62 changes: 20 additions & 42 deletions src/layer/riscv/gemm_riscv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,23 +99,10 @@ static void pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max
vfloat32m1_t _r6h = vle32_v_f32m1(p6 + 4, vl);
vfloat32m1_t _r7l = vle32_v_f32m1(p7, vl);
vfloat32m1_t _r7h = vle32_v_f32m1(p7 + 4, vl);
transpose8x8_ps(_r0l, _r0h, _r1l, _r1h, _r2l, _r2h, _r3l, _r3h, _r4l, _r4h, _r5l, _r5h, _r6l, _r6h, _r7l, _r7h, vl);
vse32_v_f32m1(pp, _r0l, vl);
vse32_v_f32m1(pp + 4, _r0h, vl);
vse32_v_f32m1(pp + 8, _r1l, vl);
vse32_v_f32m1(pp + 12, _r1h, vl);
vse32_v_f32m1(pp + 8 * 2, _r2l, vl);
vse32_v_f32m1(pp + 8 * 2 + 4, _r2h, vl);
vse32_v_f32m1(pp + 8 * 3, _r3l, vl);
vse32_v_f32m1(pp + 8 * 3 + 4, _r3h, vl);
vse32_v_f32m1(pp + 8 * 4, _r4l, vl);
vse32_v_f32m1(pp + 8 * 4 + 4, _r4h, vl);
vse32_v_f32m1(pp + 8 * 5, _r5l, vl);
vse32_v_f32m1(pp + 8 * 5 + 4, _r5h, vl);
vse32_v_f32m1(pp + 8 * 6, _r6l, vl);
vse32_v_f32m1(pp + 8 * 6 + 4, _r6h, vl);
vse32_v_f32m1(pp + 8 * 7, _r7l, vl);
vse32_v_f32m1(pp + 8 * 7 + 4, _r7h, vl);

vsseg8e32_v_f32m1(pp, _r0l, _r1l, _r2l, _r3l, _r4l, _r5l, _r6l, _r7l, vl);
vsseg8e32_v_f32m1(pp + 32, _r0h, _r1h, _r2h, _r3h, _r4h, _r5h, _r6h, _r7h, vl);

pp += 64;
p0 += 8;
p1 += 8;
Expand Down Expand Up @@ -175,7 +162,7 @@ static void pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max
vfloat32m1_t v1 = vle32_v_f32m1(p1, vl);
vfloat32m1_t v2 = vle32_v_f32m1(p2, vl);
vfloat32m1_t v3 = vle32_v_f32m1(p3, vl);
store_float_v4(v0, v1, v2, v3, pp, vl);
vsseg4e32_v_f32m1(pp, v0, v1, v2, v3, vl);
pp += 16;
p0 += 4;
p1 += 4;
Expand Down Expand Up @@ -210,7 +197,7 @@ static void pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max
{
vfloat32m1_t v0 = vle32_v_f32m1(p0, vl);
vfloat32m1_t v1 = vle32_v_f32m1(p1, vl);
store_float_v2(v0, v1, pp, vl);
vsseg2e32_v_f32m1(pp, v0, v1, vl);
pp += 8;
p0 += 4;
p1 += 4;
Expand Down Expand Up @@ -353,7 +340,7 @@ static void transpose_pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int
{
vfloat32m1_t v0 = vle32_v_f32m1(p0, vl);
vfloat32m1_t v1 = vle32_v_f32m1(p0 + 4, vl);
store_float_v2(v0, v1, pp, vl);
vsseg2e32_v_f32m1(pp, v0, v1, vl);
pp += 8;
p0 += A_hstep * 4;
}
Expand Down Expand Up @@ -562,17 +549,8 @@ static void pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int k, int max
vfloat32m1_t _r6 = vle32_v_f32m1(p6, vl);
vfloat32m1_t _r7 = vle32_v_f32m1(p7, vl);

transpose4x4_ps(_r0, _r1, _r2, _r3, vl);
transpose4x4_ps(_r4, _r5, _r6, _r7, vl);
vsseg8e32_v_f32m1(pp, _r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, vl);

vse32_v_f32m1(pp, _r0, vl);
vse32_v_f32m1(pp + 4, _r4, vl);
vse32_v_f32m1(pp + 4 * 2, _r1, vl);
vse32_v_f32m1(pp + 4 * 3, _r5, vl);
vse32_v_f32m1(pp + 4 * 4, _r2, vl);
vse32_v_f32m1(pp + 4 * 5, _r6, vl);
vse32_v_f32m1(pp + 4 * 6, _r3, vl);
vse32_v_f32m1(pp + 4 * 7, _r7, vl);
pp += 32;
p0 += 4;
p1 += 4;
Expand Down Expand Up @@ -632,7 +610,7 @@ static void pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int k, int max
vfloat32m1_t v1 = vle32_v_f32m1(p1, vl);
vfloat32m1_t v2 = vle32_v_f32m1(p2, vl);
vfloat32m1_t v3 = vle32_v_f32m1(p3, vl);
store_float_v4(v0, v1, v2, v3, pp, vl);
vsseg4e32_v_f32m1(pp, v0, v1, v2, v3, vl);
pp += 16;
p0 += 4;
p1 += 4;
Expand Down Expand Up @@ -667,7 +645,7 @@ static void pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int k, int max
{
vfloat32m1_t v0 = vle32_v_f32m1(p0, vl);
vfloat32m1_t v1 = vle32_v_f32m1(p1, vl);
store_float_v2(v0, v1, pp, vl);
vsseg2e32_v_f32m1(pp, v0, v1, vl);
pp += 8;
p0 += 4;
p1 += 4;
Expand Down Expand Up @@ -865,7 +843,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int
{
vfloat32m1_t v0 = vle32_v_f32m1(p0, vl);
vfloat32m1_t v1 = vle32_v_f32m1(p0 + 4, vl);
store_float_v2(v0, v1, pp, vl);
vsseg2e32_v_f32m1(pp, v0, v1, vl);
pp += 8;
p0 += B_hstep * 4;
}
Expand Down Expand Up @@ -937,12 +915,12 @@ static void transpose_unpack_output_tile(const Mat& topT, Mat& top_blob, int i,
vfloat32m1_t v1 = vle32_v_f32m1(pp + 8, vl);
vfloat32m1_t v2 = vle32_v_f32m1(pp + 16, vl);
vfloat32m1_t v3 = vle32_v_f32m1(pp + 24, vl);
store_float_v4(v0, v1, v2, v3, p0, vl);
vsseg4e32_v_f32m1(p0, v0, v1, v2, v3, vl);
v0 = vle32_v_f32m1(pp + 4, vl);
v1 = vle32_v_f32m1(pp + 12, vl);
v2 = vle32_v_f32m1(pp + 20, vl);
v3 = vle32_v_f32m1(pp + 28, vl);
store_float_v4(v0, v1, v2, v3, p0 + 16, vl);
vsseg4e32_v_f32m1(p0 + 16, v0, v1, v2, v3, vl);
pp += 32;
p0 += out_hstep * 4;
}
Expand Down Expand Up @@ -974,7 +952,7 @@ static void transpose_unpack_output_tile(const Mat& topT, Mat& top_blob, int i,
vfloat32m1_t v1 = vle32_v_f32m1(pp + 4, vl);
vfloat32m1_t v2 = vle32_v_f32m1(pp + 8, vl);
vfloat32m1_t v3 = vle32_v_f32m1(pp + 12, vl);
store_float_v4(v0, v1, v2, v3, p0, vl);
vsseg4e32_v_f32m1(p0, v0, v1, v2, v3, vl);
pp += 16;
p0 += out_hstep * 4;
}
Expand Down Expand Up @@ -2887,9 +2865,9 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons
}
else
{
store_float_v2(_sum00, _sum10, outptr, vl);
store_float_v2(_sum01, _sum11, outptr + 8, vl);
store_float_v2(_sum02, _sum12, outptr + 16, vl);
vsseg2e32_v_f32m1(outptr, _sum00, _sum10, vl);
vsseg2e32_v_f32m1(outptr + 8, _sum01, _sum11, vl);
vsseg2e32_v_f32m1(outptr + 16, _sum02, _sum12, vl);
}

outptr += 24;
Expand Down Expand Up @@ -2974,8 +2952,8 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons
}
else
{
store_float_v2(_sum00, _sum10, outptr, vl);
store_float_v2(_sum01, _sum11, outptr + 8, vl);
vsseg2e32_v_f32m1(outptr, _sum00, _sum10, vl);
vsseg2e32_v_f32m1(outptr + 8, _sum01, _sum11, vl);
}

outptr += 16;
Expand Down Expand Up @@ -3048,7 +3026,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons
}
else
{
store_float_v2(_sum0, _sum1, outptr, vl);
vsseg2e32_v_f32m1(outptr, _sum0, _sum1, vl);
}

outptr += 8;
Expand Down
Loading

0 comments on commit 7ac4268

Please sign in to comment.