From 21c42839a5e348d70ccfe88151784029fa253660 Mon Sep 17 00:00:00 2001 From: nihuini Date: Tue, 23 Apr 2024 19:02:30 +0800 Subject: [PATCH] wip --- src/layer/arm/gru_arm.cpp | 14 +- src/layer/arm/gru_arm.h | 2 - src/layer/arm/gru_arm_asimdhp.cpp | 1714 ++++++++++++---------------- src/layer/arm/lstm_arm.cpp | 14 +- src/layer/arm/lstm_arm.h | 2 - src/layer/arm/lstm_arm_asimdhp.cpp | 796 +++++-------- src/layer/arm/rnn_arm_asimdhp.cpp | 1 + 7 files changed, 1058 insertions(+), 1485 deletions(-) diff --git a/src/layer/arm/gru_arm.cpp b/src/layer/arm/gru_arm.cpp index db26e9babfd..c50564685ac 100644 --- a/src/layer/arm/gru_arm.cpp +++ b/src/layer/arm/gru_arm.cpp @@ -1330,12 +1330,7 @@ int GRU_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) c #if NCNN_ARM82 if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) - { - if (opt.use_fp16_arithmetic) - return forward_fp16sa(bottom_blob, top_blob, opt); - else - return forward_fp16s(bottom_blob, top_blob, opt); - } + return forward_fp16s(bottom_blob, top_blob, opt); #endif #if NCNN_BF16 @@ -1440,12 +1435,7 @@ int GRU_arm::forward(const std::vector& bottom_blobs, std::vector& top #if NCNN_ARM82 if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) - { - if (opt.use_fp16_arithmetic) - return forward_fp16sa(bottom_blobs, top_blobs, opt); - else - return forward_fp16s(bottom_blobs, top_blobs, opt); - } + return forward_fp16s(bottom_blobs, top_blobs, opt); #endif #if NCNN_BF16 diff --git a/src/layer/arm/gru_arm.h b/src/layer/arm/gru_arm.h index 62e63bc852a..b44a1f38be7 100644 --- a/src/layer/arm/gru_arm.h +++ b/src/layer/arm/gru_arm.h @@ -36,8 +36,6 @@ class GRU_arm : public GRU int create_pipeline_fp16s(const Option& opt); int forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; int forward_fp16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; - int forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; - int forward_fp16sa(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; #endif #if NCNN_BF16 int create_pipeline_bf16s(const Option& opt); diff --git a/src/layer/arm/gru_arm_asimdhp.cpp b/src/layer/arm/gru_arm_asimdhp.cpp index 0d9317daf2c..2fda651092e 100644 --- a/src/layer/arm/gru_arm_asimdhp.cpp +++ b/src/layer/arm/gru_arm_asimdhp.cpp @@ -23,7 +23,7 @@ namespace ncnn { #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -static int gru_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt) +static int gru_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt) { int size = bottom_blob.w; int T = bottom_blob.h; @@ -55,177 +55,253 @@ static int gru_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M const __fp16* weight_xc_RUN = weight_xc.row(q / 4); const __fp16* weight_hc_RUN = weight_hc.row(q / 4); - float32x4_t _gru_R = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN)); - float32x4_t _gru_U = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 4)); - float32x4_t _sum1 = vdupq_n_f32(0.f); - float32x4_t _sum2 = vdupq_n_f32(0.f); - float32x4_t _sum3 = vdupq_n_f32(0.f); - float32x4_t _sum4 = vdupq_n_f32(0.f); - float32x4_t _sum5 = vdupq_n_f32(0.f); - float32x4_t _sum6 = vdupq_n_f32(0.f); + float16x8_t _RU = vld1q_f16(bias_c_RUBNWN); + float16x8_t _sum1 = vdupq_n_f16((__fp16)0.f); + float16x8_t _sum2 = 
vdupq_n_f16((__fp16)0.f); + float16x8_t _sum3 = vdupq_n_f16((__fp16)0.f); int i = 0; for (; i + 3 < size; i += 4) { - float32x4_t _xi = vcvt_f32_f16(vld1_f16(x + i)); - float32x4_t _weight_xc_R = vcvt_f32_f16(vld1_f16(weight_xc_RUN)); - float32x4_t _weight_xc_U = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 4)); - float32x4_t _weight_xc_R_1 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 8)); - float32x4_t _weight_xc_U_1 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 12)); - float32x4_t _weight_xc_R_2 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 16)); - float32x4_t _weight_xc_U_2 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 20)); - float32x4_t _weight_xc_R_3 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 24)); - float32x4_t _weight_xc_U_3 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 28)); - _gru_R = vfmaq_laneq_f32(_gru_R, _weight_xc_R, _xi, 0); - _gru_U = vfmaq_laneq_f32(_gru_U, _weight_xc_U, _xi, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_R_1, _xi, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_U_1, _xi, 1); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_R_2, _xi, 2); - _sum4 = vfmaq_laneq_f32(_sum4, _weight_xc_U_2, _xi, 2); - _sum5 = vfmaq_laneq_f32(_sum5, _weight_xc_R_3, _xi, 3); - _sum6 = vfmaq_laneq_f32(_sum6, _weight_xc_U_3, _xi, 3); +#if NCNN_GNU_INLINE_ASM + asm volatile( + "ld1 {v4.4h}, [%0], #8 \n" + "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n" + "fmla %2.8h, v0.8h, v4.h[0] \n" + "fmla %3.8h, v1.8h, v4.h[1] \n" + "fmla %4.8h, v2.8h, v4.h[2] \n" + "fmla %5.8h, v3.8h, v4.h[3] \n" + : "=r"(x), + "=r"(weight_xc_RUN), + "=w"(_RU), + "=w"(_sum1), + "=w"(_sum2), + "=w"(_sum3) + : "0"(x), + "1"(weight_xc_RUN), + "2"(_RU), + "3"(_sum1), + "4"(_sum2), + "5"(_sum3) + : "memory", "v0", "v1", "v2", "v3", "v4"); +#else // NCNN_GNU_INLINE_ASM + float16x4_t _x = vld1_f16(x); + float16x8_t _w0 = vld1q_f16(weight_xc_RUN); + float16x8_t _w1 = vld1q_f16(weight_xc_RUN + 8); + float16x8_t _w2 = vld1q_f16(weight_xc_RUN + 16); + float16x8_t _w3 = vld1q_f16(weight_xc_RUN + 24); + _RU = vfmaq_lane_f16(_RU, _w0, _x, 0); + _sum1 = vfmaq_lane_f16(_sum1, _w1, _x, 1); + _sum2 = vfmaq_lane_f16(_sum2, _w2, _x, 2); + _sum3 = vfmaq_lane_f16(_sum3, _w3, _x, 3); + x += 4; weight_xc_RUN += 32; +#endif // NCNN_GNU_INLINE_ASM } for (; i < size; i++) { - __fp16 xi = x[i]; + __fp16 xi = *x++; - float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi)); - float32x4_t _weight_xc_R = vcvt_f32_f16(vld1_f16(weight_xc_RUN)); - float32x4_t _weight_xc_U = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 4)); - _gru_R = vmlaq_f32(_gru_R, _weight_xc_R, _xi); - _gru_U = vmlaq_f32(_gru_U, _weight_xc_U, _xi); + float16x8_t _xi = vdupq_n_f16(xi); + float16x8_t _weight_xc_RU = vld1q_f16(weight_xc_RUN); + _RU = vfmaq_f16(_RU, _weight_xc_RU, _xi); weight_xc_RUN += 8; } + const float* hidden_ptr = hidden_state; + i = 0; for (; i + 3 < num_output; i += 4) { - float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); - float32x4_t _weight_hc_R = vcvt_f32_f16(vld1_f16(weight_hc_RUN)); - float32x4_t _weight_hc_U = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 4)); - float32x4_t _weight_hc_R_1 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 8)); - float32x4_t _weight_hc_U_1 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 12)); - float32x4_t _weight_hc_R_2 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 16)); - float32x4_t _weight_hc_U_2 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 20)); - float32x4_t _weight_hc_R_3 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 24)); - float32x4_t _weight_hc_U_3 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 28)); - _gru_R = vfmaq_laneq_f32(_gru_R, _weight_hc_R, _h_cont, 0); - _gru_U = vfmaq_laneq_f32(_gru_U, 
_weight_hc_U, _h_cont, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_R_1, _h_cont, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_U_1, _h_cont, 1); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_R_2, _h_cont, 2); - _sum4 = vfmaq_laneq_f32(_sum4, _weight_hc_U_2, _h_cont, 2); - _sum5 = vfmaq_laneq_f32(_sum5, _weight_hc_R_3, _h_cont, 3); - _sum6 = vfmaq_laneq_f32(_sum6, _weight_hc_U_3, _h_cont, 3); +#if NCNN_GNU_INLINE_ASM + asm volatile( + "ld1 {v4.4s}, [%0], #16 \n" + "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n" + "fcvtn v4.4h, v4.4s \n" + "fmla %2.8h, v0.8h, v4.h[0] \n" + "fmla %3.8h, v1.8h, v4.h[1] \n" + "fmla %4.8h, v2.8h, v4.h[2] \n" + "fmla %5.8h, v3.8h, v4.h[3] \n" + : "=r"(hidden_ptr), + "=r"(weight_hc_RUN), + "=w"(_RU), + "=w"(_sum1), + "=w"(_sum2), + "=w"(_sum3) + : "0"(hidden_ptr), + "1"(weight_hc_RUN), + "2"(_RU), + "3"(_sum1), + "4"(_sum2), + "5"(_sum3) + : "memory", "v0", "v1", "v2", "v3", "v4"); +#else // NCNN_GNU_INLINE_ASM + float16x4_t _h_cont = vcvt_f16_f32(vld1q_f32(hidden_ptr)); + float16x8_t _w0 = vld1q_f16(weight_hc_RUN); + float16x8_t _w1 = vld1q_f16(weight_hc_RUN + 8); + float16x8_t _w2 = vld1q_f16(weight_hc_RUN + 16); + float16x8_t _w3 = vld1q_f16(weight_hc_RUN + 24); + _RU = vfmaq_lane_f16(_RU, _w0, _h_cont, 0); + _sum1 = vfmaq_lane_f16(_sum1, _w1, _h_cont, 1); + _sum2 = vfmaq_lane_f16(_sum2, _w2, _h_cont, 2); + _sum3 = vfmaq_lane_f16(_sum3, _w3, _h_cont, 3); + hidden_ptr += 4; weight_hc_RUN += 32; +#endif // NCNN_GNU_INLINE_ASM } for (; i < num_output; i++) { - float h_cont = hidden_state[i]; + float h_cont = *hidden_ptr++; - float32x4_t _h_cont = vdupq_n_f32(h_cont); - float32x4_t _weight_hc_R = vcvt_f32_f16(vld1_f16(weight_hc_RUN)); - float32x4_t _weight_hc_U = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 4)); - _gru_R = vmlaq_f32(_gru_R, _weight_hc_R, _h_cont); - _gru_U = vmlaq_f32(_gru_U, _weight_hc_U, _h_cont); + float16x8_t _h_cont = vdupq_n_f16((__fp16)h_cont); + float16x8_t _weight_hc_RU = vld1q_f16(weight_hc_RUN); + _RU = vfmaq_f16(_RU, _weight_hc_RU, _h_cont); weight_hc_RUN += 8; } - _gru_R = vaddq_f32(_gru_R, _sum1); - _gru_U = vaddq_f32(_gru_U, _sum2); - _sum3 = vaddq_f32(_sum3, _sum5); - _sum4 = vaddq_f32(_sum4, _sum6); - _gru_R = vaddq_f32(_gru_R, _sum3); - _gru_U = vaddq_f32(_gru_U, _sum4); + _RU = vaddq_f16(_RU, _sum1); + _sum2 = vaddq_f16(_sum2, _sum3); + _RU = vaddq_f16(_RU, _sum2); // sigmoid(R) // sigmoid(U) - _gru_R = sigmoid_ps(_gru_R); - _gru_U = sigmoid_ps(_gru_U); + float32x4_t _R32 = sigmoid_ps(vcvt_f32_f16(vget_low_f16(_RU))); + float32x4_t _U32 = sigmoid_ps(vcvt_f32_f16(vget_high_f16(_RU))); + + x -= size; + hidden_ptr = hidden_state; // gate new - float32x4_t _gru_N = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 8)); - _sum1 = vdupq_n_f32(0.f); - _sum2 = vdupq_n_f32(0.f); - _sum3 = vdupq_n_f32(0.f); + float16x4_t _gru_N = vld1_f16(bias_c_RUBNWN + 8); + float16x4_t _sum4 = vdup_n_f16((__fp16)0.f); + float16x4_t _sum5 = vdup_n_f16((__fp16)0.f); + float16x4_t _sum6 = vdup_n_f16((__fp16)0.f); i = 0; for (; i + 3 < num_output; i += 4) { - float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); - float32x4_t _weight_hc_N = vcvt_f32_f16(vld1_f16(weight_hc_RUN)); - float32x4_t _weight_hc_N_1 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 4)); - float32x4_t _weight_hc_N_2 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 8)); - float32x4_t _weight_hc_N_3 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 12)); - _gru_N = vfmaq_laneq_f32(_gru_N, _weight_hc_N, _h_cont, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_N_1, _h_cont, 1); - _sum2 = vfmaq_laneq_f32(_sum2, 
_weight_hc_N_2, _h_cont, 2); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_N_3, _h_cont, 3); +#if NCNN_GNU_INLINE_ASM + asm volatile( + "ld1 {v4.4s}, [%0], #16 \n" + "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n" + "fcvtn v4.4h, v4.4s \n" + "fmla %2.4h, v0.4h, v4.h[0] \n" + "fmla %3.4h, v1.4h, v4.h[1] \n" + "fmla %4.4h, v2.4h, v4.h[2] \n" + "fmla %5.4h, v3.4h, v4.h[3] \n" + : "=r"(hidden_ptr), + "=r"(weight_hc_RUN), + "=w"(_gru_N), + "=w"(_sum4), + "=w"(_sum5), + "=w"(_sum6) + : "0"(hidden_ptr), + "1"(weight_hc_RUN), + "2"(_gru_N), + "3"(_sum4), + "4"(_sum5), + "5"(_sum6) + : "memory", "v0", "v1", "v2", "v3", "v4"); +#else // NCNN_GNU_INLINE_ASM + float16x4_t _h_cont = vcvt_f16_f32(vld1q_f32(hidden_ptr)); + float16x4_t _w0 = vld1_f16(weight_hc_RUN); + float16x4_t _w1 = vld1_f16(weight_hc_RUN + 4); + float16x4_t _w2 = vld1_f16(weight_hc_RUN + 8); + float16x4_t _w3 = vld1_f16(weight_hc_RUN + 12); + _gru_N = vfma_lane_f16(_gru_N, _w0, _h_cont, 0); + _sum4 = vfma_lane_f16(_sum4, _w1, _h_cont, 1); + _sum5 = vfma_lane_f16(_sum5, _w2, _h_cont, 2); + _sum6 = vfma_lane_f16(_sum6, _w3, _h_cont, 3); + hidden_ptr += 4; weight_hc_RUN += 16; +#endif // NCNN_GNU_INLINE_ASM } for (; i < num_output; i++) { - float h_cont = hidden_state[i]; + float h_cont = *hidden_ptr++; - float32x4_t _h_cont = vdupq_n_f32(h_cont); - float32x4_t _weight_hc_N = vcvt_f32_f16(vld1_f16(weight_hc_RUN)); - _gru_N = vmlaq_f32(_gru_N, _weight_hc_N, _h_cont); + float16x4_t _h_cont = vdup_n_f16((__fp16)h_cont); + float16x4_t _weight_hc_N = vld1_f16(weight_hc_RUN); + _gru_N = vfma_f16(_gru_N, _weight_hc_N, _h_cont); weight_hc_RUN += 4; } - _gru_N = vaddq_f32(_gru_N, _sum1); - _sum2 = vaddq_f32(_sum2, _sum3); - _gru_N = vaddq_f32(_gru_N, _sum2); + _gru_N = vadd_f16(_gru_N, _sum4); + _sum5 = vadd_f16(_sum5, _sum6); + _gru_N = vadd_f16(_gru_N, _sum5); - _gru_N = vmlaq_f32(vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 12)), _gru_R, _gru_N); - _sum1 = vdupq_n_f32(0.f); - _sum2 = vdupq_n_f32(0.f); - _sum3 = vdupq_n_f32(0.f); + _gru_N = vfma_f16(vld1_f16(bias_c_RUBNWN + 12), vcvt_f16_f32(_R32), _gru_N); + _sum4 = vdup_n_f16((__fp16)0.f); + _sum5 = vdup_n_f16((__fp16)0.f); + _sum6 = vdup_n_f16((__fp16)0.f); i = 0; for (; i + 3 < size; i += 4) { - float32x4_t _xi = vcvt_f32_f16(vld1_f16(x + i)); - float32x4_t _weight_xc_N = vcvt_f32_f16(vld1_f16(weight_xc_RUN)); - float32x4_t _weight_xc_N_1 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 4)); - float32x4_t _weight_xc_N_2 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 8)); - float32x4_t _weight_xc_N_3 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 12)); - _gru_N = vfmaq_laneq_f32(_gru_N, _weight_xc_N, _xi, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_N_1, _xi, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_N_2, _xi, 2); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_N_3, _xi, 3); - - weight_xc_RUN += 16; - } - for (; i < size; i++) - { - __fp16 xi = x[i]; - - float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi)); - float32x4_t _weight_xc_N = vcvt_f32_f16(vld1_f16(weight_xc_RUN)); - _gru_N = vmlaq_f32(_gru_N, _weight_xc_N, _xi); - - weight_xc_RUN += 4; - } - - _gru_N = vaddq_f32(_gru_N, _sum1); - _sum2 = vaddq_f32(_sum2, _sum3); - _gru_N = vaddq_f32(_gru_N, _sum2); - - // tanh(N) - _gru_N = tanh_ps(_gru_N); - +#if NCNN_GNU_INLINE_ASM + asm volatile( + "ld1 {v4.4h}, [%0], #8 \n" + "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n" + "fmla %2.4h, v0.4h, v4.h[0] \n" + "fmla %3.4h, v1.4h, v4.h[1] \n" + "fmla %4.4h, v2.4h, v4.h[2] \n" + "fmla %5.4h, v3.4h, v4.h[3] \n" + : "=r"(x), + "=r"(weight_xc_RUN), + "=w"(_gru_N), + "=w"(_sum4), + 
"=w"(_sum5), + "=w"(_sum6) + : "0"(x), + "1"(weight_xc_RUN), + "2"(_gru_N), + "3"(_sum4), + "4"(_sum5), + "5"(_sum6) + : "memory", "v0", "v1", "v2", "v3", "v4"); +#else // NCNN_GNU_INLINE_ASM + float16x4_t _x = vld1_f16(x); + float16x4_t _w0 = vld1_f16(weight_xc_RUN); + float16x4_t _w1 = vld1_f16(weight_xc_RUN + 4); + float16x4_t _w2 = vld1_f16(weight_xc_RUN + 8); + float16x4_t _w3 = vld1_f16(weight_xc_RUN + 12); + _gru_N = vfma_lane_f16(_gru_N, _w0, _x, 0); + _sum4 = vfma_lane_f16(_sum4, _w1, _x, 1); + _sum5 = vfma_lane_f16(_sum5, _w2, _x, 2); + _sum6 = vfma_lane_f16(_sum6, _w3, _x, 3); + + x += 4; + weight_xc_RUN += 16; +#endif // NCNN_GNU_INLINE_ASM + } + for (; i < size; i++) + { + __fp16 xi = *x++; + + float16x4_t _xi = vdup_n_f16(xi); + float16x4_t _weight_xc_N = vld1_f16(weight_xc_RUN); + _gru_N = vfma_f16(_gru_N, _weight_xc_N, _xi); + + weight_xc_RUN += 4; + } + + _gru_N = vadd_f16(_gru_N, _sum4); + _sum5 = vadd_f16(_sum5, _sum6); + _gru_N = vadd_f16(_gru_N, _sum5); + + // tanh(N) + float32x4_t _N32 = tanh_ps(vcvt_f32_f16(_gru_N)); + float* gates_data = gates.row(q / 4); - vst1q_f32(gates_data, _gru_U); - vst1q_f32(gates_data + 4, _gru_N); + vst1q_f32(gates_data, _U32); + vst1q_f32(gates_data + 4, _N32); } #pragma omp parallel for num_threads(opt.num_threads) for (int q = remain_num_output_start; q < num_output; q++) @@ -238,64 +314,64 @@ static int gru_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M const __fp16* weight_xc_RUN = weight_xc.row(q / 4 + q % 4); const __fp16* weight_hc_RUN = weight_hc.row(q / 4 + q % 4); - float R = (float)bias_c_RUBNWN[0]; - float U = (float)bias_c_RUBNWN[1]; + __fp16 R = bias_c_RUBNWN[0]; + __fp16 U = bias_c_RUBNWN[1]; for (int i = 0; i < size; i++) { - float xi = (float)x[i]; + __fp16 xi = x[i]; - R += (float)weight_xc_RUN[0] * xi; - U += (float)weight_xc_RUN[1] * xi; + R += weight_xc_RUN[0] * xi; + U += weight_xc_RUN[1] * xi; weight_xc_RUN += 2; } for (int i = 0; i < num_output; i++) { - float h_cont = hidden_state[i]; + __fp16 h_cont = (__fp16)hidden_state[i]; - R += (float)weight_hc_RUN[0] * h_cont; - U += (float)weight_hc_RUN[1] * h_cont; + R += weight_hc_RUN[0] * h_cont; + U += weight_hc_RUN[1] * h_cont; weight_hc_RUN += 2; } // sigmoid(R) // sigmoid(U) - R = 1.f / (1.f + expf(-R)); - U = 1.f / (1.f + expf(-U)); + float R32 = 1.f / (1.f + expf((float)-R)); + float U32 = 1.f / (1.f + expf((float)-U)); // gate new - float N = (float)bias_c_RUBNWN[2]; + __fp16 N = bias_c_RUBNWN[2]; for (int i = 0; i < num_output; i++) { - float h_cont = hidden_state[i]; + __fp16 h_cont = (__fp16)hidden_state[i]; - N += (float)weight_hc_RUN[0] * h_cont; + N += weight_hc_RUN[0] * h_cont; weight_hc_RUN += 1; } - N = (float)bias_c_RUBNWN[3] + R * N; + N = bias_c_RUBNWN[3] + (__fp16)R32 * N; for (int i = 0; i < size; i++) { - float xi = (float)x[i]; + __fp16 xi = x[i]; - N += (float)weight_xc_RUN[0] * xi; + N += weight_xc_RUN[0] * xi; weight_xc_RUN += 1; } // tanh(N) - N = tanhf(N); + float N32 = tanhf((float)N); float* gates_data = gates.row(q / 4 + q % 4); - gates_data[0] = U; - gates_data[1] = N; + gates_data[0] = U32; + gates_data[1] = N32; } // h_t := (1 - update) .* new + update .* h_{t-1} @@ -338,8 +414,11 @@ static int gru_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M return 0; } -static int gru_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt) +static int gru_fp16s(const Mat& bottom_blob, Mat& top_blob, int 
reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt) { + if (opt.use_fp16_arithmetic) + return gru_fp16sa(bottom_blob, top_blob, reverse, weight_xc, bias_c, weight_hc, hidden_state, opt); + int size = bottom_blob.w; int T = bottom_blob.h; @@ -370,253 +449,177 @@ static int gru_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const const __fp16* weight_xc_RUN = weight_xc.row(q / 4); const __fp16* weight_hc_RUN = weight_hc.row(q / 4); - float16x8_t _RU = vld1q_f16(bias_c_RUBNWN); - float16x8_t _sum1 = vdupq_n_f16((__fp16)0.f); - float16x8_t _sum2 = vdupq_n_f16((__fp16)0.f); - float16x8_t _sum3 = vdupq_n_f16((__fp16)0.f); + float32x4_t _gru_R = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN)); + float32x4_t _gru_U = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 4)); + float32x4_t _sum1 = vdupq_n_f32(0.f); + float32x4_t _sum2 = vdupq_n_f32(0.f); + float32x4_t _sum3 = vdupq_n_f32(0.f); + float32x4_t _sum4 = vdupq_n_f32(0.f); + float32x4_t _sum5 = vdupq_n_f32(0.f); + float32x4_t _sum6 = vdupq_n_f32(0.f); int i = 0; for (; i + 3 < size; i += 4) { -#if NCNN_GNU_INLINE_ASM - asm volatile( - "ld1 {v4.4h}, [%0], #8 \n" - "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n" - "fmla %2.8h, v0.8h, v4.h[0] \n" - "fmla %3.8h, v1.8h, v4.h[1] \n" - "fmla %4.8h, v2.8h, v4.h[2] \n" - "fmla %5.8h, v3.8h, v4.h[3] \n" - : "=r"(x), - "=r"(weight_xc_RUN), - "=w"(_RU), - "=w"(_sum1), - "=w"(_sum2), - "=w"(_sum3) - : "0"(x), - "1"(weight_xc_RUN), - "2"(_RU), - "3"(_sum1), - "4"(_sum2), - "5"(_sum3) - : "memory", "v0", "v1", "v2", "v3", "v4"); -#else // NCNN_GNU_INLINE_ASM - float16x4_t _x = vld1_f16(x); - float16x8_t _w0 = vld1q_f16(weight_xc_RUN); - float16x8_t _w1 = vld1q_f16(weight_xc_RUN + 8); - float16x8_t _w2 = vld1q_f16(weight_xc_RUN + 16); - float16x8_t _w3 = vld1q_f16(weight_xc_RUN + 24); - _RU = vfmaq_lane_f16(_RU, _w0, _x, 0); - _sum1 = vfmaq_lane_f16(_sum1, _w1, _x, 1); - _sum2 = vfmaq_lane_f16(_sum2, _w2, _x, 2); - _sum3 = vfmaq_lane_f16(_sum3, _w3, _x, 3); + float32x4_t _xi = vcvt_f32_f16(vld1_f16(x + i)); + float32x4_t _weight_xc_R = vcvt_f32_f16(vld1_f16(weight_xc_RUN)); + float32x4_t _weight_xc_U = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 4)); + float32x4_t _weight_xc_R_1 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 8)); + float32x4_t _weight_xc_U_1 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 12)); + float32x4_t _weight_xc_R_2 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 16)); + float32x4_t _weight_xc_U_2 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 20)); + float32x4_t _weight_xc_R_3 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 24)); + float32x4_t _weight_xc_U_3 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 28)); + _gru_R = vfmaq_laneq_f32(_gru_R, _weight_xc_R, _xi, 0); + _gru_U = vfmaq_laneq_f32(_gru_U, _weight_xc_U, _xi, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_R_1, _xi, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_U_1, _xi, 1); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_R_2, _xi, 2); + _sum4 = vfmaq_laneq_f32(_sum4, _weight_xc_U_2, _xi, 2); + _sum5 = vfmaq_laneq_f32(_sum5, _weight_xc_R_3, _xi, 3); + _sum6 = vfmaq_laneq_f32(_sum6, _weight_xc_U_3, _xi, 3); - x += 4; weight_xc_RUN += 32; -#endif // NCNN_GNU_INLINE_ASM } for (; i < size; i++) { - __fp16 xi = *x++; + __fp16 xi = x[i]; - float16x8_t _xi = vdupq_n_f16(xi); - float16x8_t _weight_xc_RU = vld1q_f16(weight_xc_RUN); - _RU = vfmaq_f16(_RU, _weight_xc_RU, _xi); + float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi)); + float32x4_t _weight_xc_R = vcvt_f32_f16(vld1_f16(weight_xc_RUN)); + float32x4_t _weight_xc_U = 
vcvt_f32_f16(vld1_f16(weight_xc_RUN + 4)); + _gru_R = vmlaq_f32(_gru_R, _weight_xc_R, _xi); + _gru_U = vmlaq_f32(_gru_U, _weight_xc_U, _xi); weight_xc_RUN += 8; } - const float* hidden_ptr = hidden_state; - i = 0; for (; i + 3 < num_output; i += 4) { -#if NCNN_GNU_INLINE_ASM - asm volatile( - "ld1 {v4.4s}, [%0], #16 \n" - "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n" - "fcvtn v4.4h, v4.4s \n" - "fmla %2.8h, v0.8h, v4.h[0] \n" - "fmla %3.8h, v1.8h, v4.h[1] \n" - "fmla %4.8h, v2.8h, v4.h[2] \n" - "fmla %5.8h, v3.8h, v4.h[3] \n" - : "=r"(hidden_ptr), - "=r"(weight_hc_RUN), - "=w"(_RU), - "=w"(_sum1), - "=w"(_sum2), - "=w"(_sum3) - : "0"(hidden_ptr), - "1"(weight_hc_RUN), - "2"(_RU), - "3"(_sum1), - "4"(_sum2), - "5"(_sum3) - : "memory", "v0", "v1", "v2", "v3", "v4"); -#else // NCNN_GNU_INLINE_ASM - float16x4_t _h_cont = vcvt_f16_f32(vld1q_f32(hidden_ptr)); - float16x8_t _w0 = vld1q_f16(weight_hc_RUN); - float16x8_t _w1 = vld1q_f16(weight_hc_RUN + 8); - float16x8_t _w2 = vld1q_f16(weight_hc_RUN + 16); - float16x8_t _w3 = vld1q_f16(weight_hc_RUN + 24); - _RU = vfmaq_lane_f16(_RU, _w0, _h_cont, 0); - _sum1 = vfmaq_lane_f16(_sum1, _w1, _h_cont, 1); - _sum2 = vfmaq_lane_f16(_sum2, _w2, _h_cont, 2); - _sum3 = vfmaq_lane_f16(_sum3, _w3, _h_cont, 3); + float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); + float32x4_t _weight_hc_R = vcvt_f32_f16(vld1_f16(weight_hc_RUN)); + float32x4_t _weight_hc_U = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 4)); + float32x4_t _weight_hc_R_1 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 8)); + float32x4_t _weight_hc_U_1 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 12)); + float32x4_t _weight_hc_R_2 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 16)); + float32x4_t _weight_hc_U_2 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 20)); + float32x4_t _weight_hc_R_3 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 24)); + float32x4_t _weight_hc_U_3 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 28)); + _gru_R = vfmaq_laneq_f32(_gru_R, _weight_hc_R, _h_cont, 0); + _gru_U = vfmaq_laneq_f32(_gru_U, _weight_hc_U, _h_cont, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_R_1, _h_cont, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_U_1, _h_cont, 1); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_R_2, _h_cont, 2); + _sum4 = vfmaq_laneq_f32(_sum4, _weight_hc_U_2, _h_cont, 2); + _sum5 = vfmaq_laneq_f32(_sum5, _weight_hc_R_3, _h_cont, 3); + _sum6 = vfmaq_laneq_f32(_sum6, _weight_hc_U_3, _h_cont, 3); - hidden_ptr += 4; weight_hc_RUN += 32; -#endif // NCNN_GNU_INLINE_ASM } for (; i < num_output; i++) { - float h_cont = *hidden_ptr++; + float h_cont = hidden_state[i]; - float16x8_t _h_cont = vdupq_n_f16((__fp16)h_cont); - float16x8_t _weight_hc_RU = vld1q_f16(weight_hc_RUN); - _RU = vfmaq_f16(_RU, _weight_hc_RU, _h_cont); + float32x4_t _h_cont = vdupq_n_f32(h_cont); + float32x4_t _weight_hc_R = vcvt_f32_f16(vld1_f16(weight_hc_RUN)); + float32x4_t _weight_hc_U = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 4)); + _gru_R = vmlaq_f32(_gru_R, _weight_hc_R, _h_cont); + _gru_U = vmlaq_f32(_gru_U, _weight_hc_U, _h_cont); weight_hc_RUN += 8; } - _RU = vaddq_f16(_RU, _sum1); - _sum2 = vaddq_f16(_sum2, _sum3); - _RU = vaddq_f16(_RU, _sum2); + _gru_R = vaddq_f32(_gru_R, _sum1); + _gru_U = vaddq_f32(_gru_U, _sum2); + _sum3 = vaddq_f32(_sum3, _sum5); + _sum4 = vaddq_f32(_sum4, _sum6); + _gru_R = vaddq_f32(_gru_R, _sum3); + _gru_U = vaddq_f32(_gru_U, _sum4); // sigmoid(R) // sigmoid(U) - float32x4_t _R32 = sigmoid_ps(vcvt_f32_f16(vget_low_f16(_RU))); - float32x4_t _U32 = sigmoid_ps(vcvt_f32_f16(vget_high_f16(_RU))); - - x -= size; - 
hidden_ptr = hidden_state; + _gru_R = sigmoid_ps(_gru_R); + _gru_U = sigmoid_ps(_gru_U); // gate new - float16x4_t _gru_N = vld1_f16(bias_c_RUBNWN + 8); - float16x4_t _sum4 = vdup_n_f16((__fp16)0.f); - float16x4_t _sum5 = vdup_n_f16((__fp16)0.f); - float16x4_t _sum6 = vdup_n_f16((__fp16)0.f); - - i = 0; + float32x4_t _gru_N = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 8)); + _sum1 = vdupq_n_f32(0.f); + _sum2 = vdupq_n_f32(0.f); + _sum3 = vdupq_n_f32(0.f); + + i = 0; for (; i + 3 < num_output; i += 4) { -#if NCNN_GNU_INLINE_ASM - asm volatile( - "ld1 {v4.4s}, [%0], #16 \n" - "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n" - "fcvtn v4.4h, v4.4s \n" - "fmla %2.4h, v0.4h, v4.h[0] \n" - "fmla %3.4h, v1.4h, v4.h[1] \n" - "fmla %4.4h, v2.4h, v4.h[2] \n" - "fmla %5.4h, v3.4h, v4.h[3] \n" - : "=r"(hidden_ptr), - "=r"(weight_hc_RUN), - "=w"(_gru_N), - "=w"(_sum4), - "=w"(_sum5), - "=w"(_sum6) - : "0"(hidden_ptr), - "1"(weight_hc_RUN), - "2"(_gru_N), - "3"(_sum4), - "4"(_sum5), - "5"(_sum6) - : "memory", "v0", "v1", "v2", "v3", "v4"); -#else // NCNN_GNU_INLINE_ASM - float16x4_t _h_cont = vcvt_f16_f32(vld1q_f32(hidden_ptr)); - float16x4_t _w0 = vld1_f16(weight_hc_RUN); - float16x4_t _w1 = vld1_f16(weight_hc_RUN + 4); - float16x4_t _w2 = vld1_f16(weight_hc_RUN + 8); - float16x4_t _w3 = vld1_f16(weight_hc_RUN + 12); - _gru_N = vfma_lane_f16(_gru_N, _w0, _h_cont, 0); - _sum4 = vfma_lane_f16(_sum4, _w1, _h_cont, 1); - _sum5 = vfma_lane_f16(_sum5, _w2, _h_cont, 2); - _sum6 = vfma_lane_f16(_sum6, _w3, _h_cont, 3); + float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); + float32x4_t _weight_hc_N = vcvt_f32_f16(vld1_f16(weight_hc_RUN)); + float32x4_t _weight_hc_N_1 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 4)); + float32x4_t _weight_hc_N_2 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 8)); + float32x4_t _weight_hc_N_3 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 12)); + _gru_N = vfmaq_laneq_f32(_gru_N, _weight_hc_N, _h_cont, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_N_1, _h_cont, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_N_2, _h_cont, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_N_3, _h_cont, 3); - hidden_ptr += 4; weight_hc_RUN += 16; -#endif // NCNN_GNU_INLINE_ASM } for (; i < num_output; i++) { - float h_cont = *hidden_ptr++; + float h_cont = hidden_state[i]; - float16x4_t _h_cont = vdup_n_f16((__fp16)h_cont); - float16x4_t _weight_hc_N = vld1_f16(weight_hc_RUN); - _gru_N = vfma_f16(_gru_N, _weight_hc_N, _h_cont); + float32x4_t _h_cont = vdupq_n_f32(h_cont); + float32x4_t _weight_hc_N = vcvt_f32_f16(vld1_f16(weight_hc_RUN)); + _gru_N = vmlaq_f32(_gru_N, _weight_hc_N, _h_cont); weight_hc_RUN += 4; } - _gru_N = vadd_f16(_gru_N, _sum4); - _sum5 = vadd_f16(_sum5, _sum6); - _gru_N = vadd_f16(_gru_N, _sum5); + _gru_N = vaddq_f32(_gru_N, _sum1); + _sum2 = vaddq_f32(_sum2, _sum3); + _gru_N = vaddq_f32(_gru_N, _sum2); - _gru_N = vfma_f16(vld1_f16(bias_c_RUBNWN + 12), vcvt_f16_f32(_R32), _gru_N); - _sum4 = vdup_n_f16((__fp16)0.f); - _sum5 = vdup_n_f16((__fp16)0.f); - _sum6 = vdup_n_f16((__fp16)0.f); + _gru_N = vmlaq_f32(vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 12)), _gru_R, _gru_N); + _sum1 = vdupq_n_f32(0.f); + _sum2 = vdupq_n_f32(0.f); + _sum3 = vdupq_n_f32(0.f); i = 0; for (; i + 3 < size; i += 4) { -#if NCNN_GNU_INLINE_ASM - asm volatile( - "ld1 {v4.4h}, [%0], #8 \n" - "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n" - "fmla %2.4h, v0.4h, v4.h[0] \n" - "fmla %3.4h, v1.4h, v4.h[1] \n" - "fmla %4.4h, v2.4h, v4.h[2] \n" - "fmla %5.4h, v3.4h, v4.h[3] \n" - : "=r"(x), - "=r"(weight_xc_RUN), - "=w"(_gru_N), - 
"=w"(_sum4), - "=w"(_sum5), - "=w"(_sum6) - : "0"(x), - "1"(weight_xc_RUN), - "2"(_gru_N), - "3"(_sum4), - "4"(_sum5), - "5"(_sum6) - : "memory", "v0", "v1", "v2", "v3", "v4"); -#else // NCNN_GNU_INLINE_ASM - float16x4_t _x = vld1_f16(x); - float16x4_t _w0 = vld1_f16(weight_xc_RUN); - float16x4_t _w1 = vld1_f16(weight_xc_RUN + 4); - float16x4_t _w2 = vld1_f16(weight_xc_RUN + 8); - float16x4_t _w3 = vld1_f16(weight_xc_RUN + 12); - _gru_N = vfma_lane_f16(_gru_N, _w0, _x, 0); - _sum4 = vfma_lane_f16(_sum4, _w1, _x, 1); - _sum5 = vfma_lane_f16(_sum5, _w2, _x, 2); - _sum6 = vfma_lane_f16(_sum6, _w3, _x, 3); + float32x4_t _xi = vcvt_f32_f16(vld1_f16(x + i)); + float32x4_t _weight_xc_N = vcvt_f32_f16(vld1_f16(weight_xc_RUN)); + float32x4_t _weight_xc_N_1 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 4)); + float32x4_t _weight_xc_N_2 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 8)); + float32x4_t _weight_xc_N_3 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 12)); + _gru_N = vfmaq_laneq_f32(_gru_N, _weight_xc_N, _xi, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_N_1, _xi, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_N_2, _xi, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_N_3, _xi, 3); - x += 4; weight_xc_RUN += 16; -#endif // NCNN_GNU_INLINE_ASM } for (; i < size; i++) { - __fp16 xi = *x++; + __fp16 xi = x[i]; - float16x4_t _xi = vdup_n_f16(xi); - float16x4_t _weight_xc_N = vld1_f16(weight_xc_RUN); - _gru_N = vfma_f16(_gru_N, _weight_xc_N, _xi); + float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi)); + float32x4_t _weight_xc_N = vcvt_f32_f16(vld1_f16(weight_xc_RUN)); + _gru_N = vmlaq_f32(_gru_N, _weight_xc_N, _xi); weight_xc_RUN += 4; } - _gru_N = vadd_f16(_gru_N, _sum4); - _sum5 = vadd_f16(_sum5, _sum6); - _gru_N = vadd_f16(_gru_N, _sum5); + _gru_N = vaddq_f32(_gru_N, _sum1); + _sum2 = vaddq_f32(_sum2, _sum3); + _gru_N = vaddq_f32(_gru_N, _sum2); // tanh(N) - float32x4_t _N32 = tanh_ps(vcvt_f32_f16(_gru_N)); + _gru_N = tanh_ps(_gru_N); float* gates_data = gates.row(q / 4); - vst1q_f32(gates_data, _U32); - vst1q_f32(gates_data + 4, _N32); + vst1q_f32(gates_data, _gru_U); + vst1q_f32(gates_data + 4, _gru_N); } #pragma omp parallel for num_threads(opt.num_threads) for (int q = remain_num_output_start; q < num_output; q++) @@ -629,64 +632,64 @@ static int gru_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const const __fp16* weight_xc_RUN = weight_xc.row(q / 4 + q % 4); const __fp16* weight_hc_RUN = weight_hc.row(q / 4 + q % 4); - __fp16 R = bias_c_RUBNWN[0]; - __fp16 U = bias_c_RUBNWN[1]; + float R = (float)bias_c_RUBNWN[0]; + float U = (float)bias_c_RUBNWN[1]; for (int i = 0; i < size; i++) { - __fp16 xi = x[i]; + float xi = (float)x[i]; - R += weight_xc_RUN[0] * xi; - U += weight_xc_RUN[1] * xi; + R += (float)weight_xc_RUN[0] * xi; + U += (float)weight_xc_RUN[1] * xi; weight_xc_RUN += 2; } for (int i = 0; i < num_output; i++) { - __fp16 h_cont = (__fp16)hidden_state[i]; + float h_cont = hidden_state[i]; - R += weight_hc_RUN[0] * h_cont; - U += weight_hc_RUN[1] * h_cont; + R += (float)weight_hc_RUN[0] * h_cont; + U += (float)weight_hc_RUN[1] * h_cont; weight_hc_RUN += 2; } // sigmoid(R) // sigmoid(U) - float R32 = 1.f / (1.f + expf((float)-R)); - float U32 = 1.f / (1.f + expf((float)-U)); + R = 1.f / (1.f + expf(-R)); + U = 1.f / (1.f + expf(-U)); // gate new - __fp16 N = bias_c_RUBNWN[2]; + float N = (float)bias_c_RUBNWN[2]; for (int i = 0; i < num_output; i++) { - __fp16 h_cont = (__fp16)hidden_state[i]; + float h_cont = hidden_state[i]; - N += weight_hc_RUN[0] * h_cont; + N += (float)weight_hc_RUN[0] 
* h_cont; weight_hc_RUN += 1; } - N = bias_c_RUBNWN[3] + (__fp16)R32 * N; + N = (float)bias_c_RUBNWN[3] + R * N; for (int i = 0; i < size; i++) { - __fp16 xi = x[i]; + float xi = (float)x[i]; - N += weight_xc_RUN[0] * xi; + N += (float)weight_xc_RUN[0] * xi; weight_xc_RUN += 1; } // tanh(N) - float N32 = tanhf((float)N); + N = tanhf(N); float* gates_data = gates.row(q / 4 + q % 4); - gates_data[0] = U32; - gates_data[1] = N32; + gates_data[0] = U; + gates_data[1] = N; } // h_t := (1 - update) .* new + update .* h_{t-1} @@ -730,7 +733,7 @@ static int gru_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const } #if NCNN_INT8 -static int gru_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const Mat& weight_xc_int8_descales, const Mat& bias_c, const Mat& weight_hc_int8, const Mat& weight_hc_int8_descales, Mat& hidden_state, const Option& opt) +static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const Mat& weight_xc_int8_descales, const Mat& bias_c, const Mat& weight_hc_int8, const Mat& weight_hc_int8_descales, Mat& hidden_state, const Option& opt) { int size = bottom_blob.w; int T = bottom_blob.h; @@ -762,224 +765,334 @@ static int gru_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co const signed char* weight_xc_int8_RUN = weight_xc_int8.row(q / 4); const signed char* weight_hc_int8_RUN = weight_hc_int8.row(q / 4); - const float* weight_xc_int8_descales_RUN = weight_xc_int8_descales.row(q / 4); - const float* weight_hc_int8_descales_RUN = weight_hc_int8_descales.row(q / 4); + const __fp16* weight_xc_int8_descales_RUN = weight_xc_int8_descales.row(q / 4); + const __fp16* weight_hc_int8_descales_RUN = weight_hc_int8_descales.row(q / 4); - float32x4_t _descale_xc_R = vld1q_f32(weight_xc_int8_descales_RUN); - float32x4_t _descale_xc_U = vld1q_f32(weight_xc_int8_descales_RUN + 4); - float32x4_t _descale_hc_R = vld1q_f32(weight_hc_int8_descales_RUN); - float32x4_t _descale_hc_U = vld1q_f32(weight_hc_int8_descales_RUN + 4); + float16x8_t _descale_xc_RU = vld1q_f16(weight_xc_int8_descales_RUN); + float16x8_t _descale_hc_RU = vld1q_f16(weight_hc_int8_descales_RUN); - float32x4_t _gru_R = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN)); - float32x4_t _gru_U = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 4)); - float32x4_t _sum1 = vdupq_n_f32(0.f); - float32x4_t _sum2 = vdupq_n_f32(0.f); - float32x4_t _sum3 = vdupq_n_f32(0.f); - float32x4_t _sum4 = vdupq_n_f32(0.f); - float32x4_t _sum5 = vdupq_n_f32(0.f); - float32x4_t _sum6 = vdupq_n_f32(0.f); + float16x8_t _RU = vld1q_f16(bias_c_RUBNWN); + float16x8_t _sum1 = vdupq_n_f16((__fp16)0.f); + float16x8_t _sum2 = vdupq_n_f16((__fp16)0.f); + float16x8_t _sum3 = vdupq_n_f16((__fp16)0.f); int i = 0; for (; i + 3 < size; i += 4) { - float32x4_t _xi = vcvt_f32_f16(vld1_f16(x + i)); +#if NCNN_GNU_INLINE_ASM + asm volatile( + "ld1 {v6.16b, v7.16b}, [%1], #32 \n" + "ld1 {v4.4h}, [%0], #8 \n" + "sxtl v0.8h, v6.8b \n" + "sxtl2 v1.8h, v6.16b \n" + "sxtl v2.8h, v7.8b \n" + "sxtl2 v3.8h, v7.16b \n" + "scvtf v0.8h, v0.8h \n" + "scvtf v1.8h, v1.8h \n" + "scvtf v2.8h, v2.8h \n" + "scvtf v3.8h, v3.8h \n" + "fmul v0.8h, v0.8h, %12.8h \n" + "fmul v1.8h, v1.8h, %12.8h \n" + "fmul v2.8h, v2.8h, %12.8h \n" + "fmul v3.8h, v3.8h, %12.8h \n" + "fmla %2.8h, v0.8h, v4.h[0] \n" + "fmla %3.8h, v1.8h, v4.h[1] \n" + "fmla %4.8h, v2.8h, v4.h[2] \n" + "fmla %5.8h, v3.8h, v4.h[3] \n" + : "=r"(x), + "=r"(weight_xc_int8_RUN), + "=w"(_RU), + "=w"(_sum1), + "=w"(_sum2), + "=w"(_sum3) + : 
"0"(x), + "1"(weight_xc_int8_RUN), + "2"(_RU), + "3"(_sum1), + "4"(_sum2), + "5"(_sum3), + "w"(_descale_xc_RU) + : "memory", "v0", "v1", "v2", "v3", "v4", "v6", "v7"); +#else // NCNN_GNU_INLINE_ASM + float16x4_t _x = vld1_f16(x); int8x16_t _weight_xc_RU01 = vld1q_s8(weight_xc_int8_RUN); int8x16_t _weight_xc_RU23 = vld1q_s8(weight_xc_int8_RUN + 16); - int16x8_t _weight_xc_RU0 = vmovl_s8(vget_low_s8(_weight_xc_RU01)); - int16x8_t _weight_xc_RU1 = vmovl_s8(vget_high_s8(_weight_xc_RU01)); - int16x8_t _weight_xc_RU2 = vmovl_s8(vget_low_s8(_weight_xc_RU23)); - int16x8_t _weight_xc_RU3 = vmovl_s8(vget_high_s8(_weight_xc_RU23)); + float16x8_t _w0 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_xc_RU01))), _descale_xc_RU); + float16x8_t _w1 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_xc_RU01))), _descale_xc_RU); + float16x8_t _w2 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_xc_RU23))), _descale_xc_RU); + float16x8_t _w3 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_xc_RU23))), _descale_xc_RU); - float32x4_t _weight_xc_R = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU0))), _descale_xc_R); - float32x4_t _weight_xc_U = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU0))), _descale_xc_U); - float32x4_t _weight_xc_R_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU1))), _descale_xc_R); - float32x4_t _weight_xc_U_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU1))), _descale_xc_U); - float32x4_t _weight_xc_R_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU2))), _descale_xc_R); - float32x4_t _weight_xc_U_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU2))), _descale_xc_U); - float32x4_t _weight_xc_R_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU3))), _descale_xc_R); - float32x4_t _weight_xc_U_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU3))), _descale_xc_U); - - _gru_R = vfmaq_laneq_f32(_gru_R, _weight_xc_R, _xi, 0); - _gru_U = vfmaq_laneq_f32(_gru_U, _weight_xc_U, _xi, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_R_1, _xi, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_U_1, _xi, 1); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_R_2, _xi, 2); - _sum4 = vfmaq_laneq_f32(_sum4, _weight_xc_U_2, _xi, 2); - _sum5 = vfmaq_laneq_f32(_sum5, _weight_xc_R_3, _xi, 3); - _sum6 = vfmaq_laneq_f32(_sum6, _weight_xc_U_3, _xi, 3); + _RU = vfmaq_lane_f16(_RU, _w0, _x, 0); + _sum1 = vfmaq_lane_f16(_sum1, _w1, _x, 1); + _sum2 = vfmaq_lane_f16(_sum2, _w2, _x, 2); + _sum3 = vfmaq_lane_f16(_sum3, _w3, _x, 3); + x += 4; weight_xc_int8_RUN += 32; +#endif // NCNN_GNU_INLINE_ASM } for (; i < size; i++) { - __fp16 xi = x[i]; + __fp16 xi = *x++; - float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi)); + float16x8_t _xi = vdupq_n_f16(xi); - int16x8_t _weight_xc_RU = vmovl_s8(vld1_s8(weight_xc_int8_RUN)); - float32x4_t _weight_xc_R = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU))), _descale_xc_R); - float32x4_t _weight_xc_U = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU))), _descale_xc_U); + float16x8_t _weight_xc_RU = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vld1_s8(weight_xc_int8_RUN))), _descale_xc_RU); - _gru_R = vmlaq_f32(_gru_R, _weight_xc_R, _xi); - _gru_U = vmlaq_f32(_gru_U, _weight_xc_U, _xi); + _RU = vfmaq_f16(_RU, _weight_xc_RU, _xi); weight_xc_int8_RUN += 8; } + const float* hidden_ptr = hidden_state; + i = 0; for (; i + 3 < num_output; i += 4) { - float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); +#if NCNN_GNU_INLINE_ASM + 
asm volatile( + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "ld1 {v4.4s}, [%0], #16 \n" + "sxtl v0.8h, v6.8b \n" + "sxtl2 v1.8h, v6.16b \n" + "sxtl v2.8h, v7.8b \n" + "sxtl2 v3.8h, v7.16b \n" + "scvtf v0.8h, v0.8h \n" + "scvtf v1.8h, v1.8h \n" + "scvtf v2.8h, v2.8h \n" + "scvtf v3.8h, v3.8h \n" + "fcvtn v4.4h, v4.4s \n" + "fmul v0.8h, v0.8h, %12.8h \n" + "fmul v1.8h, v1.8h, %12.8h \n" + "fmul v2.8h, v2.8h, %12.8h \n" + "fmul v3.8h, v3.8h, %12.8h \n" + "fmla %2.8h, v0.8h, v4.h[0] \n" + "fmla %3.8h, v1.8h, v4.h[1] \n" + "fmla %4.8h, v2.8h, v4.h[2] \n" + "fmla %5.8h, v3.8h, v4.h[3] \n" + : "=r"(hidden_ptr), + "=r"(weight_hc_int8_RUN), + "=w"(_RU), + "=w"(_sum1), + "=w"(_sum2), + "=w"(_sum3) + : "0"(hidden_ptr), + "1"(weight_hc_int8_RUN), + "2"(_RU), + "3"(_sum1), + "4"(_sum2), + "5"(_sum3), + "w"(_descale_hc_RU) + : "memory", "v0", "v1", "v2", "v3", "v4", "v6", "v7"); +#else // NCNN_GNU_INLINE_ASM + float16x4_t _h_cont = vcvt_f16_f32(vld1q_f32(hidden_ptr)); int8x16_t _weight_hc_RU01 = vld1q_s8(weight_hc_int8_RUN); int8x16_t _weight_hc_RU23 = vld1q_s8(weight_hc_int8_RUN + 16); - int16x8_t _weight_hc_RU0 = vmovl_s8(vget_low_s8(_weight_hc_RU01)); - int16x8_t _weight_hc_RU1 = vmovl_s8(vget_high_s8(_weight_hc_RU01)); - int16x8_t _weight_hc_RU2 = vmovl_s8(vget_low_s8(_weight_hc_RU23)); - int16x8_t _weight_hc_RU3 = vmovl_s8(vget_high_s8(_weight_hc_RU23)); - - float32x4_t _weight_hc_R = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU0))), _descale_hc_R); - float32x4_t _weight_hc_U = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU0))), _descale_hc_U); - float32x4_t _weight_hc_R_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU1))), _descale_hc_R); - float32x4_t _weight_hc_U_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU1))), _descale_hc_U); - float32x4_t _weight_hc_R_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU2))), _descale_hc_R); - float32x4_t _weight_hc_U_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU2))), _descale_hc_U); - float32x4_t _weight_hc_R_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU3))), _descale_hc_R); - float32x4_t _weight_hc_U_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU3))), _descale_hc_U); + float16x8_t _w0 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc_RU01))), _descale_hc_RU); + float16x8_t _w1 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc_RU01))), _descale_hc_RU); + float16x8_t _w2 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc_RU23))), _descale_hc_RU); + float16x8_t _w3 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc_RU23))), _descale_hc_RU); - _gru_R = vfmaq_laneq_f32(_gru_R, _weight_hc_R, _h_cont, 0); - _gru_U = vfmaq_laneq_f32(_gru_U, _weight_hc_U, _h_cont, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_R_1, _h_cont, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_U_1, _h_cont, 1); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_R_2, _h_cont, 2); - _sum4 = vfmaq_laneq_f32(_sum4, _weight_hc_U_2, _h_cont, 2); - _sum5 = vfmaq_laneq_f32(_sum5, _weight_hc_R_3, _h_cont, 3); - _sum6 = vfmaq_laneq_f32(_sum6, _weight_hc_U_3, _h_cont, 3); + _RU = vfmaq_lane_f16(_RU, _w0, _h_cont, 0); + _sum1 = vfmaq_lane_f16(_sum1, _w1, _h_cont, 1); + _sum2 = vfmaq_lane_f16(_sum2, _w2, _h_cont, 2); + _sum3 = vfmaq_lane_f16(_sum3, _w3, _h_cont, 3); + hidden_ptr += 4; weight_hc_int8_RUN += 32; +#endif // NCNN_GNU_INLINE_ASM } for (; i < num_output; i++) { - float h_cont = hidden_state[i]; + float h_cont = *hidden_ptr++; - 
float32x4_t _h_cont = vdupq_n_f32(h_cont); + float16x8_t _h_cont = vdupq_n_f16((__fp16)h_cont); - int16x8_t _weight_hc_RU = vmovl_s8(vld1_s8(weight_hc_int8_RUN)); - float32x4_t _weight_hc_R = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU))), _descale_hc_R); - float32x4_t _weight_hc_U = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU))), _descale_hc_U); + float16x8_t _weight_hc_RU = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vld1_s8(weight_hc_int8_RUN))), _descale_hc_RU); - _gru_R = vmlaq_f32(_gru_R, _weight_hc_R, _h_cont); - _gru_U = vmlaq_f32(_gru_U, _weight_hc_U, _h_cont); + _RU = vfmaq_f16(_RU, _weight_hc_RU, _h_cont); weight_hc_int8_RUN += 8; } - _gru_R = vaddq_f32(_gru_R, _sum1); - _gru_U = vaddq_f32(_gru_U, _sum2); - _sum3 = vaddq_f32(_sum3, _sum5); - _sum4 = vaddq_f32(_sum4, _sum6); - _gru_R = vaddq_f32(_gru_R, _sum3); - _gru_U = vaddq_f32(_gru_U, _sum4); + _RU = vaddq_f16(_RU, _sum1); + _sum2 = vaddq_f16(_sum2, _sum3); + _RU = vaddq_f16(_RU, _sum2); // sigmoid(R) // sigmoid(U) - _gru_R = sigmoid_ps(_gru_R); - _gru_U = sigmoid_ps(_gru_U); + float32x4_t _R32 = sigmoid_ps(vcvt_f32_f16(vget_low_f16(_RU))); + float32x4_t _U32 = sigmoid_ps(vcvt_f32_f16(vget_high_f16(_RU))); + + x -= size; + hidden_ptr = hidden_state; // gate new - float32x4_t _gru_N = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 8)); - _sum1 = vdupq_n_f32(0.f); - _sum2 = vdupq_n_f32(0.f); - _sum3 = vdupq_n_f32(0.f); + float16x4_t _gru_N = vld1_f16(bias_c_RUBNWN + 8); + float16x4_t _sum4 = vdup_n_f16((__fp16)0.f); + float16x4_t _sum5 = vdup_n_f16((__fp16)0.f); + float16x4_t _sum6 = vdup_n_f16((__fp16)0.f); - float32x4_t _descale_xc_N = vld1q_f32(weight_xc_int8_descales_RUN + 8); - float32x4_t _descale_hc_N = vld1q_f32(weight_hc_int8_descales_RUN + 8); + float16x4_t _descale_xc_N = vld1_f16(weight_xc_int8_descales_RUN + 8); + float16x4_t _descale_hc_N = vld1_f16(weight_hc_int8_descales_RUN + 8); + float16x8_t _descale_xc_NN = vcombine_f16(_descale_xc_N, _descale_xc_N); + float16x8_t _descale_hc_NN = vcombine_f16(_descale_hc_N, _descale_hc_N); i = 0; for (; i + 3 < num_output; i += 4) { - float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); +#if NCNN_GNU_INLINE_ASM + asm volatile( + "ld1 {v5.16b}, [%1], #16 \n" + "ld1 {v4.4s}, [%0], #16 \n" + "sxtl v0.8h, v5.8b \n" + "sxtl2 v2.8h, v5.16b \n" + "scvtf v0.8h, v0.8h \n" + "scvtf v2.8h, v2.8h \n" + "fcvtn v4.4h, v4.4s \n" + "fmul v0.8h, v0.8h, %12.8h \n" + "fmul v2.8h, v2.8h, %12.8h \n" + "mov v1.d[0], v0.d[1] \n" + "mov v3.d[0], v2.d[1] \n" + "fmla %2.4h, v0.4h, v4.h[0] \n" + "fmla %3.4h, v1.4h, v4.h[1] \n" + "fmla %4.4h, v2.4h, v4.h[2] \n" + "fmla %5.4h, v3.4h, v4.h[3] \n" + : "=r"(hidden_ptr), + "=r"(weight_hc_int8_RUN), + "=w"(_gru_N), + "=w"(_sum4), + "=w"(_sum5), + "=w"(_sum6) + : "0"(hidden_ptr), + "1"(weight_hc_int8_RUN), + "2"(_gru_N), + "3"(_sum4), + "4"(_sum5), + "5"(_sum6), + "w"(_descale_hc_NN) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5"); +#else // NCNN_GNU_INLINE_ASM + float16x4_t _h_cont = vcvt_f16_f32(vld1q_f32(hidden_ptr)); int8x16_t _weight_hc_N0123 = vld1q_s8(weight_hc_int8_RUN); - int16x8_t _weight_hc_N01 = vmovl_s8(vget_low_s8(_weight_hc_N0123)); - int16x8_t _weight_hc_N23 = vmovl_s8(vget_high_s8(_weight_hc_N0123)); - float32x4_t _weight_hc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_N01))), _descale_hc_N); - float32x4_t _weight_hc_N_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_N01))), _descale_hc_N); - float32x4_t _weight_hc_N_2 = 
vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_N23))), _descale_hc_N); - float32x4_t _weight_hc_N_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_N23))), _descale_hc_N); + float16x8_t _weight_hc_N01 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc_N0123))), _descale_hc_NN); + float16x8_t _weight_hc_N23 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc_N0123))), _descale_hc_NN); - _gru_N = vfmaq_laneq_f32(_gru_N, _weight_hc_N, _h_cont, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_N_1, _h_cont, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_N_2, _h_cont, 2); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_N_3, _h_cont, 3); + float16x4_t _w0 = vget_low_f16(_weight_hc_N01); + float16x4_t _w1 = vget_high_f16(_weight_hc_N01); + float16x4_t _w2 = vget_low_f16(_weight_hc_N23); + float16x4_t _w3 = vget_high_f16(_weight_hc_N23); + + _gru_N = vfma_lane_f16(_gru_N, _w0, _h_cont, 0); + _sum4 = vfma_lane_f16(_sum4, _w1, _h_cont, 1); + _sum5 = vfma_lane_f16(_sum5, _w2, _h_cont, 2); + _sum6 = vfma_lane_f16(_sum6, _w3, _h_cont, 3); + hidden_ptr += 4; weight_hc_int8_RUN += 16; +#endif // NCNN_GNU_INLINE_ASM } for (; i < num_output; i++) { - float h_cont = hidden_state[i]; + float h_cont = *hidden_ptr++; - float32x4_t _h_cont = vdupq_n_f32(h_cont); - float32x4_t _weight_hc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_RUN))))), _descale_hc_N); - _gru_N = vmlaq_f32(_gru_N, _weight_hc_N, _h_cont); + float16x4_t _h_cont = vdup_n_f16((__fp16)h_cont); + float16x4_t _weight_hc_N = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_RUN)))), _descale_hc_N); + _gru_N = vfma_f16(_gru_N, _weight_hc_N, _h_cont); weight_hc_int8_RUN += 4; } - _gru_N = vaddq_f32(_gru_N, _sum1); - _sum2 = vaddq_f32(_sum2, _sum3); - _gru_N = vaddq_f32(_gru_N, _sum2); + _gru_N = vadd_f16(_gru_N, _sum4); + _sum5 = vadd_f16(_sum5, _sum6); + _gru_N = vadd_f16(_gru_N, _sum5); - _gru_N = vmlaq_f32(vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 12)), _gru_R, _gru_N); - _sum1 = vdupq_n_f32(0.f); - _sum2 = vdupq_n_f32(0.f); - _sum3 = vdupq_n_f32(0.f); + _gru_N = vfma_f16(vld1_f16(bias_c_RUBNWN + 12), vcvt_f16_f32(_R32), _gru_N); + _sum4 = vdup_n_f16((__fp16)0.f); + _sum5 = vdup_n_f16((__fp16)0.f); + _sum6 = vdup_n_f16((__fp16)0.f); i = 0; for (; i + 3 < size; i += 4) { - float32x4_t _xi = vcvt_f32_f16(vld1_f16(x + i)); +#if NCNN_GNU_INLINE_ASM + asm volatile( + "ld1 {v5.16b}, [%1], #16 \n" + "ld1 {v4.4h}, [%0], #8 \n" + "sxtl v0.8h, v5.8b \n" + "sxtl2 v2.8h, v5.16b \n" + "scvtf v0.8h, v0.8h \n" + "scvtf v2.8h, v2.8h \n" + "fmul v0.8h, v0.8h, %12.8h \n" + "fmul v2.8h, v2.8h, %12.8h \n" + "mov v1.d[0], v0.d[1] \n" + "mov v3.d[0], v2.d[1] \n" + "fmla %2.4h, v0.4h, v4.h[0] \n" + "fmla %3.4h, v1.4h, v4.h[1] \n" + "fmla %4.4h, v2.4h, v4.h[2] \n" + "fmla %5.4h, v3.4h, v4.h[3] \n" + : "=r"(x), + "=r"(weight_xc_int8_RUN), + "=w"(_gru_N), + "=w"(_sum4), + "=w"(_sum5), + "=w"(_sum6) + : "0"(x), + "1"(weight_xc_int8_RUN), + "2"(_gru_N), + "3"(_sum4), + "4"(_sum5), + "5"(_sum6), + "w"(_descale_xc_NN) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5"); +#else // NCNN_GNU_INLINE_ASM + float16x4_t _x = vld1_f16(x); int8x16_t _weight_xc_N0123 = vld1q_s8(weight_xc_int8_RUN); - int16x8_t _weight_xc_N01 = vmovl_s8(vget_low_s8(_weight_xc_N0123)); - int16x8_t _weight_xc_N23 = vmovl_s8(vget_high_s8(_weight_xc_N0123)); - float32x4_t _weight_xc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_N01))), _descale_xc_N); - float32x4_t _weight_xc_N_1 = 
vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_N01))), _descale_xc_N); - float32x4_t _weight_xc_N_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_N23))), _descale_xc_N); - float32x4_t _weight_xc_N_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_N23))), _descale_xc_N); + float16x8_t _weight_xc_N01 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_xc_N0123))), _descale_xc_NN); + float16x8_t _weight_xc_N23 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_xc_N0123))), _descale_xc_NN); - _gru_N = vfmaq_laneq_f32(_gru_N, _weight_xc_N, _xi, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_N_1, _xi, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_N_2, _xi, 2); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_N_3, _xi, 3); + float16x4_t _w0 = vget_low_f16(_weight_xc_N01); + float16x4_t _w1 = vget_high_f16(_weight_xc_N01); + float16x4_t _w2 = vget_low_f16(_weight_xc_N23); + float16x4_t _w3 = vget_high_f16(_weight_xc_N23); + + _gru_N = vfma_lane_f16(_gru_N, _w0, _x, 0); + _sum4 = vfma_lane_f16(_sum4, _w1, _x, 1); + _sum5 = vfma_lane_f16(_sum5, _w2, _x, 2); + _sum6 = vfma_lane_f16(_sum6, _w3, _x, 3); + x += 4; weight_xc_int8_RUN += 16; +#endif // NCNN_GNU_INLINE_ASM } for (; i < size; i++) { - __fp16 xi = x[i]; + __fp16 xi = *x++; - float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi)); - float32x4_t _weight_xc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_RUN))))), _descale_xc_N); - _gru_N = vmlaq_f32(_gru_N, _weight_xc_N, _xi); + float16x4_t _xi = vdup_n_f16(xi); + float16x4_t _weight_xc_N = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_RUN)))), _descale_xc_N); + _gru_N = vfma_f16(_gru_N, _weight_xc_N, _xi); weight_xc_int8_RUN += 4; } - _gru_N = vaddq_f32(_gru_N, _sum1); - _sum2 = vaddq_f32(_sum2, _sum3); - _gru_N = vaddq_f32(_gru_N, _sum2); + _gru_N = vadd_f16(_gru_N, _sum4); + _sum5 = vadd_f16(_sum5, _sum6); + _gru_N = vadd_f16(_gru_N, _sum5); // tanh(N) - _gru_N = tanh_ps(_gru_N); + float32x4_t _N32 = tanh_ps(vcvt_f32_f16(_gru_N)); float* gates_data = gates.row(q / 4); - vst1q_f32(gates_data, _gru_U); - vst1q_f32(gates_data + 4, _gru_N); + vst1q_f32(gates_data, _U32); + vst1q_f32(gates_data + 4, _N32); } #pragma omp parallel for num_threads(opt.num_threads) for (int q = remain_num_output_start; q < num_output; q++) @@ -991,23 +1104,23 @@ static int gru_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co const signed char* weight_xc_int8_RUN = weight_xc_int8.row(q / 4 + q % 4); const signed char* weight_hc_int8_RUN = weight_hc_int8.row(q / 4 + q % 4); - const float* weight_xc_int8_descales_RUN = weight_xc_int8_descales.row(q / 4 + q % 4); - const float* weight_hc_int8_descales_RUN = weight_hc_int8_descales.row(q / 4 + q % 4); + const __fp16* weight_xc_int8_descales_RUN = weight_xc_int8_descales.row(q / 4 + q % 4); + const __fp16* weight_hc_int8_descales_RUN = weight_hc_int8_descales.row(q / 4 + q % 4); - const float descale_xc_R = weight_xc_int8_descales_RUN[0]; - const float descale_xc_U = weight_xc_int8_descales_RUN[1]; - const float descale_xc_N = weight_xc_int8_descales_RUN[2]; + const __fp16 descale_xc_R = weight_xc_int8_descales_RUN[0]; + const __fp16 descale_xc_U = weight_xc_int8_descales_RUN[1]; + const __fp16 descale_xc_N = weight_xc_int8_descales_RUN[2]; - const float descale_hc_R = weight_hc_int8_descales_RUN[0]; - const float descale_hc_U = weight_hc_int8_descales_RUN[1]; - const float descale_hc_N = weight_hc_int8_descales_RUN[2]; + const __fp16 descale_hc_R = 
weight_hc_int8_descales_RUN[0]; + const __fp16 descale_hc_U = weight_hc_int8_descales_RUN[1]; + const __fp16 descale_hc_N = weight_hc_int8_descales_RUN[2]; - float R = (float)bias_c_RUBNWN[0]; - float U = (float)bias_c_RUBNWN[1]; + __fp16 R = bias_c_RUBNWN[0]; + __fp16 U = bias_c_RUBNWN[1]; for (int i = 0; i < size; i++) { - float xi = (float)x[i]; + __fp16 xi = x[i]; R += weight_xc_int8_RUN[0] * descale_xc_R * xi; U += weight_xc_int8_RUN[1] * descale_xc_U * xi; @@ -1017,7 +1130,7 @@ static int gru_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co for (int i = 0; i < num_output; i++) { - float h_cont = hidden_state[i]; + __fp16 h_cont = (__fp16)hidden_state[i]; R += weight_hc_int8_RUN[0] * descale_hc_R * h_cont; U += weight_hc_int8_RUN[1] * descale_hc_U * h_cont; @@ -1027,26 +1140,26 @@ static int gru_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co // sigmoid(R) // sigmoid(U) - R = 1.f / (1.f + expf(-R)); - U = 1.f / (1.f + expf(-U)); + float R32 = 1.f / (1.f + expf((float)-R)); + float U32 = 1.f / (1.f + expf((float)-U)); // gate new - float N = (float)bias_c_RUBNWN[2]; + __fp16 N = bias_c_RUBNWN[2]; for (int i = 0; i < num_output; i++) { - float h_cont = hidden_state[i]; + __fp16 h_cont = (__fp16)hidden_state[i]; N += weight_hc_int8_RUN[0] * descale_hc_N * h_cont; weight_hc_int8_RUN += 1; } - N = (float)bias_c_RUBNWN[3] + R * N; + N = bias_c_RUBNWN[3] + (__fp16)R32 * N; for (int i = 0; i < size; i++) { - float xi = (float)x[i]; + __fp16 xi = x[i]; N += weight_xc_int8_RUN[0] * descale_xc_N * xi; @@ -1054,12 +1167,12 @@ static int gru_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co } // tanh(N) - N = tanhf(N); + float N32 = tanhf((float)N); float* gates_data = gates.row(q / 4 + q % 4); - gates_data[0] = U; - gates_data[1] = N; + gates_data[0] = U32; + gates_data[1] = N32; } // h_t := (1 - update) .* new + update .* h_{t-1} @@ -1102,8 +1215,11 @@ static int gru_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co return 0; } -static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const Mat& weight_xc_int8_descales, const Mat& bias_c, const Mat& weight_hc_int8, const Mat& weight_hc_int8_descales, Mat& hidden_state, const Option& opt) +static int gru_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const Mat& weight_xc_int8_descales, const Mat& bias_c, const Mat& weight_hc_int8, const Mat& weight_hc_int8_descales, Mat& hidden_state, const Option& opt) { + if (opt.use_fp16_arithmetic) + return gru_fp16sa_int8(bottom_blob, top_blob, reverse, weight_xc_int8, weight_xc_int8_descales, bias_c, weight_hc_int8, weight_hc_int8_descales, hidden_state, opt); + int size = bottom_blob.w; int T = bottom_blob.h; @@ -1134,334 +1250,224 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c const signed char* weight_xc_int8_RUN = weight_xc_int8.row(q / 4); const signed char* weight_hc_int8_RUN = weight_hc_int8.row(q / 4); - const __fp16* weight_xc_int8_descales_RUN = weight_xc_int8_descales.row(q / 4); - const __fp16* weight_hc_int8_descales_RUN = weight_hc_int8_descales.row(q / 4); + const float* weight_xc_int8_descales_RUN = weight_xc_int8_descales.row(q / 4); + const float* weight_hc_int8_descales_RUN = weight_hc_int8_descales.row(q / 4); - float16x8_t _descale_xc_RU = vld1q_f16(weight_xc_int8_descales_RUN); - float16x8_t _descale_hc_RU = vld1q_f16(weight_hc_int8_descales_RUN); + float32x4_t _descale_xc_R = 
vld1q_f32(weight_xc_int8_descales_RUN); + float32x4_t _descale_xc_U = vld1q_f32(weight_xc_int8_descales_RUN + 4); + float32x4_t _descale_hc_R = vld1q_f32(weight_hc_int8_descales_RUN); + float32x4_t _descale_hc_U = vld1q_f32(weight_hc_int8_descales_RUN + 4); - float16x8_t _RU = vld1q_f16(bias_c_RUBNWN); - float16x8_t _sum1 = vdupq_n_f16((__fp16)0.f); - float16x8_t _sum2 = vdupq_n_f16((__fp16)0.f); - float16x8_t _sum3 = vdupq_n_f16((__fp16)0.f); + float32x4_t _gru_R = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN)); + float32x4_t _gru_U = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 4)); + float32x4_t _sum1 = vdupq_n_f32(0.f); + float32x4_t _sum2 = vdupq_n_f32(0.f); + float32x4_t _sum3 = vdupq_n_f32(0.f); + float32x4_t _sum4 = vdupq_n_f32(0.f); + float32x4_t _sum5 = vdupq_n_f32(0.f); + float32x4_t _sum6 = vdupq_n_f32(0.f); int i = 0; for (; i + 3 < size; i += 4) { -#if NCNN_GNU_INLINE_ASM - asm volatile( - "ld1 {v6.16b, v7.16b}, [%1], #32 \n" - "ld1 {v4.4h}, [%0], #8 \n" - "sxtl v0.8h, v6.8b \n" - "sxtl2 v1.8h, v6.16b \n" - "sxtl v2.8h, v7.8b \n" - "sxtl2 v3.8h, v7.16b \n" - "scvtf v0.8h, v0.8h \n" - "scvtf v1.8h, v1.8h \n" - "scvtf v2.8h, v2.8h \n" - "scvtf v3.8h, v3.8h \n" - "fmul v0.8h, v0.8h, %12.8h \n" - "fmul v1.8h, v1.8h, %12.8h \n" - "fmul v2.8h, v2.8h, %12.8h \n" - "fmul v3.8h, v3.8h, %12.8h \n" - "fmla %2.8h, v0.8h, v4.h[0] \n" - "fmla %3.8h, v1.8h, v4.h[1] \n" - "fmla %4.8h, v2.8h, v4.h[2] \n" - "fmla %5.8h, v3.8h, v4.h[3] \n" - : "=r"(x), - "=r"(weight_xc_int8_RUN), - "=w"(_RU), - "=w"(_sum1), - "=w"(_sum2), - "=w"(_sum3) - : "0"(x), - "1"(weight_xc_int8_RUN), - "2"(_RU), - "3"(_sum1), - "4"(_sum2), - "5"(_sum3), - "w"(_descale_xc_RU) - : "memory", "v0", "v1", "v2", "v3", "v4", "v6", "v7"); -#else // NCNN_GNU_INLINE_ASM - float16x4_t _x = vld1_f16(x); + float32x4_t _xi = vcvt_f32_f16(vld1_f16(x + i)); int8x16_t _weight_xc_RU01 = vld1q_s8(weight_xc_int8_RUN); int8x16_t _weight_xc_RU23 = vld1q_s8(weight_xc_int8_RUN + 16); - float16x8_t _w0 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_xc_RU01))), _descale_xc_RU); - float16x8_t _w1 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_xc_RU01))), _descale_xc_RU); - float16x8_t _w2 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_xc_RU23))), _descale_xc_RU); - float16x8_t _w3 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_xc_RU23))), _descale_xc_RU); + int16x8_t _weight_xc_RU0 = vmovl_s8(vget_low_s8(_weight_xc_RU01)); + int16x8_t _weight_xc_RU1 = vmovl_s8(vget_high_s8(_weight_xc_RU01)); + int16x8_t _weight_xc_RU2 = vmovl_s8(vget_low_s8(_weight_xc_RU23)); + int16x8_t _weight_xc_RU3 = vmovl_s8(vget_high_s8(_weight_xc_RU23)); - _RU = vfmaq_lane_f16(_RU, _w0, _x, 0); - _sum1 = vfmaq_lane_f16(_sum1, _w1, _x, 1); - _sum2 = vfmaq_lane_f16(_sum2, _w2, _x, 2); - _sum3 = vfmaq_lane_f16(_sum3, _w3, _x, 3); + float32x4_t _weight_xc_R = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU0))), _descale_xc_R); + float32x4_t _weight_xc_U = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU0))), _descale_xc_U); + float32x4_t _weight_xc_R_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU1))), _descale_xc_R); + float32x4_t _weight_xc_U_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU1))), _descale_xc_U); + float32x4_t _weight_xc_R_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU2))), _descale_xc_R); + float32x4_t _weight_xc_U_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU2))), _descale_xc_U); + float32x4_t _weight_xc_R_3 = 
vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU3))), _descale_xc_R); + float32x4_t _weight_xc_U_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU3))), _descale_xc_U); + + _gru_R = vfmaq_laneq_f32(_gru_R, _weight_xc_R, _xi, 0); + _gru_U = vfmaq_laneq_f32(_gru_U, _weight_xc_U, _xi, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_R_1, _xi, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_U_1, _xi, 1); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_R_2, _xi, 2); + _sum4 = vfmaq_laneq_f32(_sum4, _weight_xc_U_2, _xi, 2); + _sum5 = vfmaq_laneq_f32(_sum5, _weight_xc_R_3, _xi, 3); + _sum6 = vfmaq_laneq_f32(_sum6, _weight_xc_U_3, _xi, 3); - x += 4; weight_xc_int8_RUN += 32; -#endif // NCNN_GNU_INLINE_ASM } for (; i < size; i++) { - __fp16 xi = *x++; + __fp16 xi = x[i]; - float16x8_t _xi = vdupq_n_f16(xi); + float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi)); - float16x8_t _weight_xc_RU = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vld1_s8(weight_xc_int8_RUN))), _descale_xc_RU); + int16x8_t _weight_xc_RU = vmovl_s8(vld1_s8(weight_xc_int8_RUN)); + float32x4_t _weight_xc_R = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_RU))), _descale_xc_R); + float32x4_t _weight_xc_U = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_RU))), _descale_xc_U); - _RU = vfmaq_f16(_RU, _weight_xc_RU, _xi); + _gru_R = vmlaq_f32(_gru_R, _weight_xc_R, _xi); + _gru_U = vmlaq_f32(_gru_U, _weight_xc_U, _xi); weight_xc_int8_RUN += 8; } - const float* hidden_ptr = hidden_state; - i = 0; for (; i + 3 < num_output; i += 4) { -#if NCNN_GNU_INLINE_ASM - asm volatile( - "ld1 {v6.8h, v7.8h}, [%1], #32 \n" - "ld1 {v4.4s}, [%0], #16 \n" - "sxtl v0.8h, v6.8b \n" - "sxtl2 v1.8h, v6.16b \n" - "sxtl v2.8h, v7.8b \n" - "sxtl2 v3.8h, v7.16b \n" - "scvtf v0.8h, v0.8h \n" - "scvtf v1.8h, v1.8h \n" - "scvtf v2.8h, v2.8h \n" - "scvtf v3.8h, v3.8h \n" - "fcvtn v4.4h, v4.4s \n" - "fmul v0.8h, v0.8h, %12.8h \n" - "fmul v1.8h, v1.8h, %12.8h \n" - "fmul v2.8h, v2.8h, %12.8h \n" - "fmul v3.8h, v3.8h, %12.8h \n" - "fmla %2.8h, v0.8h, v4.h[0] \n" - "fmla %3.8h, v1.8h, v4.h[1] \n" - "fmla %4.8h, v2.8h, v4.h[2] \n" - "fmla %5.8h, v3.8h, v4.h[3] \n" - : "=r"(hidden_ptr), - "=r"(weight_hc_int8_RUN), - "=w"(_RU), - "=w"(_sum1), - "=w"(_sum2), - "=w"(_sum3) - : "0"(hidden_ptr), - "1"(weight_hc_int8_RUN), - "2"(_RU), - "3"(_sum1), - "4"(_sum2), - "5"(_sum3), - "w"(_descale_hc_RU) - : "memory", "v0", "v1", "v2", "v3", "v4", "v6", "v7"); -#else // NCNN_GNU_INLINE_ASM - float16x4_t _h_cont = vcvt_f16_f32(vld1q_f32(hidden_ptr)); + float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); int8x16_t _weight_hc_RU01 = vld1q_s8(weight_hc_int8_RUN); int8x16_t _weight_hc_RU23 = vld1q_s8(weight_hc_int8_RUN + 16); - float16x8_t _w0 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc_RU01))), _descale_hc_RU); - float16x8_t _w1 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc_RU01))), _descale_hc_RU); - float16x8_t _w2 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc_RU23))), _descale_hc_RU); - float16x8_t _w3 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc_RU23))), _descale_hc_RU); + int16x8_t _weight_hc_RU0 = vmovl_s8(vget_low_s8(_weight_hc_RU01)); + int16x8_t _weight_hc_RU1 = vmovl_s8(vget_high_s8(_weight_hc_RU01)); + int16x8_t _weight_hc_RU2 = vmovl_s8(vget_low_s8(_weight_hc_RU23)); + int16x8_t _weight_hc_RU3 = vmovl_s8(vget_high_s8(_weight_hc_RU23)); - _RU = vfmaq_lane_f16(_RU, _w0, _h_cont, 0); - _sum1 = vfmaq_lane_f16(_sum1, _w1, _h_cont, 1); - _sum2 = vfmaq_lane_f16(_sum2, _w2, _h_cont, 2); - 
_sum3 = vfmaq_lane_f16(_sum3, _w3, _h_cont, 3); + float32x4_t _weight_hc_R = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU0))), _descale_hc_R); + float32x4_t _weight_hc_U = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU0))), _descale_hc_U); + float32x4_t _weight_hc_R_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU1))), _descale_hc_R); + float32x4_t _weight_hc_U_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU1))), _descale_hc_U); + float32x4_t _weight_hc_R_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU2))), _descale_hc_R); + float32x4_t _weight_hc_U_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU2))), _descale_hc_U); + float32x4_t _weight_hc_R_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU3))), _descale_hc_R); + float32x4_t _weight_hc_U_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU3))), _descale_hc_U); + + _gru_R = vfmaq_laneq_f32(_gru_R, _weight_hc_R, _h_cont, 0); + _gru_U = vfmaq_laneq_f32(_gru_U, _weight_hc_U, _h_cont, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_R_1, _h_cont, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_U_1, _h_cont, 1); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_R_2, _h_cont, 2); + _sum4 = vfmaq_laneq_f32(_sum4, _weight_hc_U_2, _h_cont, 2); + _sum5 = vfmaq_laneq_f32(_sum5, _weight_hc_R_3, _h_cont, 3); + _sum6 = vfmaq_laneq_f32(_sum6, _weight_hc_U_3, _h_cont, 3); - hidden_ptr += 4; weight_hc_int8_RUN += 32; -#endif // NCNN_GNU_INLINE_ASM } for (; i < num_output; i++) { - float h_cont = *hidden_ptr++; + float h_cont = hidden_state[i]; - float16x8_t _h_cont = vdupq_n_f16((__fp16)h_cont); + float32x4_t _h_cont = vdupq_n_f32(h_cont); - float16x8_t _weight_hc_RU = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vld1_s8(weight_hc_int8_RUN))), _descale_hc_RU); + int16x8_t _weight_hc_RU = vmovl_s8(vld1_s8(weight_hc_int8_RUN)); + float32x4_t _weight_hc_R = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_RU))), _descale_hc_R); + float32x4_t _weight_hc_U = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_RU))), _descale_hc_U); - _RU = vfmaq_f16(_RU, _weight_hc_RU, _h_cont); + _gru_R = vmlaq_f32(_gru_R, _weight_hc_R, _h_cont); + _gru_U = vmlaq_f32(_gru_U, _weight_hc_U, _h_cont); weight_hc_int8_RUN += 8; } - _RU = vaddq_f16(_RU, _sum1); - _sum2 = vaddq_f16(_sum2, _sum3); - _RU = vaddq_f16(_RU, _sum2); + _gru_R = vaddq_f32(_gru_R, _sum1); + _gru_U = vaddq_f32(_gru_U, _sum2); + _sum3 = vaddq_f32(_sum3, _sum5); + _sum4 = vaddq_f32(_sum4, _sum6); + _gru_R = vaddq_f32(_gru_R, _sum3); + _gru_U = vaddq_f32(_gru_U, _sum4); // sigmoid(R) // sigmoid(U) - float32x4_t _R32 = sigmoid_ps(vcvt_f32_f16(vget_low_f16(_RU))); - float32x4_t _U32 = sigmoid_ps(vcvt_f32_f16(vget_high_f16(_RU))); - - x -= size; - hidden_ptr = hidden_state; + _gru_R = sigmoid_ps(_gru_R); + _gru_U = sigmoid_ps(_gru_U); // gate new - float16x4_t _gru_N = vld1_f16(bias_c_RUBNWN + 8); - float16x4_t _sum4 = vdup_n_f16((__fp16)0.f); - float16x4_t _sum5 = vdup_n_f16((__fp16)0.f); - float16x4_t _sum6 = vdup_n_f16((__fp16)0.f); + float32x4_t _gru_N = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 8)); + _sum1 = vdupq_n_f32(0.f); + _sum2 = vdupq_n_f32(0.f); + _sum3 = vdupq_n_f32(0.f); - float16x4_t _descale_xc_N = vld1_f16(weight_xc_int8_descales_RUN + 8); - float16x4_t _descale_hc_N = vld1_f16(weight_hc_int8_descales_RUN + 8); - float16x8_t _descale_xc_NN = vcombine_f16(_descale_xc_N, _descale_xc_N); - float16x8_t _descale_hc_NN = vcombine_f16(_descale_hc_N, _descale_hc_N); + float32x4_t 
_descale_xc_N = vld1q_f32(weight_xc_int8_descales_RUN + 8); + float32x4_t _descale_hc_N = vld1q_f32(weight_hc_int8_descales_RUN + 8); i = 0; for (; i + 3 < num_output; i += 4) { -#if NCNN_GNU_INLINE_ASM - asm volatile( - "ld1 {v5.16b}, [%1], #16 \n" - "ld1 {v4.4s}, [%0], #16 \n" - "sxtl v0.8h, v5.8b \n" - "sxtl2 v2.8h, v5.16b \n" - "scvtf v0.8h, v0.8h \n" - "scvtf v2.8h, v2.8h \n" - "fcvtn v4.4h, v4.4s \n" - "fmul v0.8h, v0.8h, %12.8h \n" - "fmul v2.8h, v2.8h, %12.8h \n" - "mov v1.d[0], v0.d[1] \n" - "mov v3.d[0], v2.d[1] \n" - "fmla %2.4h, v0.4h, v4.h[0] \n" - "fmla %3.4h, v1.4h, v4.h[1] \n" - "fmla %4.4h, v2.4h, v4.h[2] \n" - "fmla %5.4h, v3.4h, v4.h[3] \n" - : "=r"(hidden_ptr), - "=r"(weight_hc_int8_RUN), - "=w"(_gru_N), - "=w"(_sum4), - "=w"(_sum5), - "=w"(_sum6) - : "0"(hidden_ptr), - "1"(weight_hc_int8_RUN), - "2"(_gru_N), - "3"(_sum4), - "4"(_sum5), - "5"(_sum6), - "w"(_descale_hc_NN) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5"); -#else // NCNN_GNU_INLINE_ASM - float16x4_t _h_cont = vcvt_f16_f32(vld1q_f32(hidden_ptr)); + float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); int8x16_t _weight_hc_N0123 = vld1q_s8(weight_hc_int8_RUN); - float16x8_t _weight_hc_N01 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc_N0123))), _descale_hc_NN); - float16x8_t _weight_hc_N23 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc_N0123))), _descale_hc_NN); - - float16x4_t _w0 = vget_low_f16(_weight_hc_N01); - float16x4_t _w1 = vget_high_f16(_weight_hc_N01); - float16x4_t _w2 = vget_low_f16(_weight_hc_N23); - float16x4_t _w3 = vget_high_f16(_weight_hc_N23); + int16x8_t _weight_hc_N01 = vmovl_s8(vget_low_s8(_weight_hc_N0123)); + int16x8_t _weight_hc_N23 = vmovl_s8(vget_high_s8(_weight_hc_N0123)); + float32x4_t _weight_hc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_N01))), _descale_hc_N); + float32x4_t _weight_hc_N_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_N01))), _descale_hc_N); + float32x4_t _weight_hc_N_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_N23))), _descale_hc_N); + float32x4_t _weight_hc_N_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_N23))), _descale_hc_N); - _gru_N = vfma_lane_f16(_gru_N, _w0, _h_cont, 0); - _sum4 = vfma_lane_f16(_sum4, _w1, _h_cont, 1); - _sum5 = vfma_lane_f16(_sum5, _w2, _h_cont, 2); - _sum6 = vfma_lane_f16(_sum6, _w3, _h_cont, 3); + _gru_N = vfmaq_laneq_f32(_gru_N, _weight_hc_N, _h_cont, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_N_1, _h_cont, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_N_2, _h_cont, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_N_3, _h_cont, 3); - hidden_ptr += 4; weight_hc_int8_RUN += 16; -#endif // NCNN_GNU_INLINE_ASM } for (; i < num_output; i++) { - float h_cont = *hidden_ptr++; + float h_cont = hidden_state[i]; - float16x4_t _h_cont = vdup_n_f16((__fp16)h_cont); - float16x4_t _weight_hc_N = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_RUN)))), _descale_hc_N); - _gru_N = vfma_f16(_gru_N, _weight_hc_N, _h_cont); + float32x4_t _h_cont = vdupq_n_f32(h_cont); + float32x4_t _weight_hc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_RUN))))), _descale_hc_N); + _gru_N = vmlaq_f32(_gru_N, _weight_hc_N, _h_cont); weight_hc_int8_RUN += 4; } - _gru_N = vadd_f16(_gru_N, _sum4); - _sum5 = vadd_f16(_sum5, _sum6); - _gru_N = vadd_f16(_gru_N, _sum5); + _gru_N = vaddq_f32(_gru_N, _sum1); + _sum2 = vaddq_f32(_sum2, _sum3); + _gru_N = vaddq_f32(_gru_N, _sum2); - _gru_N = 
vfma_f16(vld1_f16(bias_c_RUBNWN + 12), vcvt_f16_f32(_R32), _gru_N); - _sum4 = vdup_n_f16((__fp16)0.f); - _sum5 = vdup_n_f16((__fp16)0.f); - _sum6 = vdup_n_f16((__fp16)0.f); + _gru_N = vmlaq_f32(vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 12)), _gru_R, _gru_N); + _sum1 = vdupq_n_f32(0.f); + _sum2 = vdupq_n_f32(0.f); + _sum3 = vdupq_n_f32(0.f); i = 0; for (; i + 3 < size; i += 4) { -#if NCNN_GNU_INLINE_ASM - asm volatile( - "ld1 {v5.16b}, [%1], #16 \n" - "ld1 {v4.4h}, [%0], #8 \n" - "sxtl v0.8h, v5.8b \n" - "sxtl2 v2.8h, v5.16b \n" - "scvtf v0.8h, v0.8h \n" - "scvtf v2.8h, v2.8h \n" - "fmul v0.8h, v0.8h, %12.8h \n" - "fmul v2.8h, v2.8h, %12.8h \n" - "mov v1.d[0], v0.d[1] \n" - "mov v3.d[0], v2.d[1] \n" - "fmla %2.4h, v0.4h, v4.h[0] \n" - "fmla %3.4h, v1.4h, v4.h[1] \n" - "fmla %4.4h, v2.4h, v4.h[2] \n" - "fmla %5.4h, v3.4h, v4.h[3] \n" - : "=r"(x), - "=r"(weight_xc_int8_RUN), - "=w"(_gru_N), - "=w"(_sum4), - "=w"(_sum5), - "=w"(_sum6) - : "0"(x), - "1"(weight_xc_int8_RUN), - "2"(_gru_N), - "3"(_sum4), - "4"(_sum5), - "5"(_sum6), - "w"(_descale_xc_NN) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5"); -#else // NCNN_GNU_INLINE_ASM - float16x4_t _x = vld1_f16(x); + float32x4_t _xi = vcvt_f32_f16(vld1_f16(x + i)); int8x16_t _weight_xc_N0123 = vld1q_s8(weight_xc_int8_RUN); - float16x8_t _weight_xc_N01 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_xc_N0123))), _descale_xc_NN); - float16x8_t _weight_xc_N23 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_xc_N0123))), _descale_xc_NN); - - float16x4_t _w0 = vget_low_f16(_weight_xc_N01); - float16x4_t _w1 = vget_high_f16(_weight_xc_N01); - float16x4_t _w2 = vget_low_f16(_weight_xc_N23); - float16x4_t _w3 = vget_high_f16(_weight_xc_N23); - - _gru_N = vfma_lane_f16(_gru_N, _w0, _x, 0); - _sum4 = vfma_lane_f16(_sum4, _w1, _x, 1); - _sum5 = vfma_lane_f16(_sum5, _w2, _x, 2); - _sum6 = vfma_lane_f16(_sum6, _w3, _x, 3); + int16x8_t _weight_xc_N01 = vmovl_s8(vget_low_s8(_weight_xc_N0123)); + int16x8_t _weight_xc_N23 = vmovl_s8(vget_high_s8(_weight_xc_N0123)); + float32x4_t _weight_xc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_N01))), _descale_xc_N); + float32x4_t _weight_xc_N_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_N01))), _descale_xc_N); + float32x4_t _weight_xc_N_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_N23))), _descale_xc_N); + float32x4_t _weight_xc_N_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_N23))), _descale_xc_N); + + _gru_N = vfmaq_laneq_f32(_gru_N, _weight_xc_N, _xi, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_N_1, _xi, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_N_2, _xi, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_N_3, _xi, 3); - x += 4; weight_xc_int8_RUN += 16; -#endif // NCNN_GNU_INLINE_ASM } for (; i < size; i++) { - __fp16 xi = *x++; + __fp16 xi = x[i]; - float16x4_t _xi = vdup_n_f16(xi); - float16x4_t _weight_xc_N = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_RUN)))), _descale_xc_N); - _gru_N = vfma_f16(_gru_N, _weight_xc_N, _xi); + float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi)); + float32x4_t _weight_xc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_RUN))))), _descale_xc_N); + _gru_N = vmlaq_f32(_gru_N, _weight_xc_N, _xi); weight_xc_int8_RUN += 4; } - _gru_N = vadd_f16(_gru_N, _sum4); - _sum5 = vadd_f16(_sum5, _sum6); - _gru_N = vadd_f16(_gru_N, _sum5); + _gru_N = vaddq_f32(_gru_N, _sum1); + _sum2 = vaddq_f32(_sum2, _sum3); + _gru_N = vaddq_f32(_gru_N, _sum2); // tanh(N) - 
float32x4_t _N32 = tanh_ps(vcvt_f32_f16(_gru_N)); + _gru_N = tanh_ps(_gru_N); float* gates_data = gates.row(q / 4); - vst1q_f32(gates_data, _U32); - vst1q_f32(gates_data + 4, _N32); + vst1q_f32(gates_data, _gru_U); + vst1q_f32(gates_data + 4, _gru_N); } #pragma omp parallel for num_threads(opt.num_threads) for (int q = remain_num_output_start; q < num_output; q++) @@ -1473,23 +1479,23 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c const signed char* weight_xc_int8_RUN = weight_xc_int8.row(q / 4 + q % 4); const signed char* weight_hc_int8_RUN = weight_hc_int8.row(q / 4 + q % 4); - const __fp16* weight_xc_int8_descales_RUN = weight_xc_int8_descales.row(q / 4 + q % 4); - const __fp16* weight_hc_int8_descales_RUN = weight_hc_int8_descales.row(q / 4 + q % 4); + const float* weight_xc_int8_descales_RUN = weight_xc_int8_descales.row(q / 4 + q % 4); + const float* weight_hc_int8_descales_RUN = weight_hc_int8_descales.row(q / 4 + q % 4); - const __fp16 descale_xc_R = weight_xc_int8_descales_RUN[0]; - const __fp16 descale_xc_U = weight_xc_int8_descales_RUN[1]; - const __fp16 descale_xc_N = weight_xc_int8_descales_RUN[2]; + const float descale_xc_R = weight_xc_int8_descales_RUN[0]; + const float descale_xc_U = weight_xc_int8_descales_RUN[1]; + const float descale_xc_N = weight_xc_int8_descales_RUN[2]; - const __fp16 descale_hc_R = weight_hc_int8_descales_RUN[0]; - const __fp16 descale_hc_U = weight_hc_int8_descales_RUN[1]; - const __fp16 descale_hc_N = weight_hc_int8_descales_RUN[2]; + const float descale_hc_R = weight_hc_int8_descales_RUN[0]; + const float descale_hc_U = weight_hc_int8_descales_RUN[1]; + const float descale_hc_N = weight_hc_int8_descales_RUN[2]; - __fp16 R = bias_c_RUBNWN[0]; - __fp16 U = bias_c_RUBNWN[1]; + float R = (float)bias_c_RUBNWN[0]; + float U = (float)bias_c_RUBNWN[1]; for (int i = 0; i < size; i++) { - __fp16 xi = x[i]; + float xi = (float)x[i]; R += weight_xc_int8_RUN[0] * descale_xc_R * xi; U += weight_xc_int8_RUN[1] * descale_xc_U * xi; @@ -1499,7 +1505,7 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c for (int i = 0; i < num_output; i++) { - __fp16 h_cont = (__fp16)hidden_state[i]; + float h_cont = hidden_state[i]; R += weight_hc_int8_RUN[0] * descale_hc_R * h_cont; U += weight_hc_int8_RUN[1] * descale_hc_U * h_cont; @@ -1509,26 +1515,26 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c // sigmoid(R) // sigmoid(U) - float R32 = 1.f / (1.f + expf((float)-R)); - float U32 = 1.f / (1.f + expf((float)-U)); + R = 1.f / (1.f + expf(-R)); + U = 1.f / (1.f + expf(-U)); // gate new - __fp16 N = bias_c_RUBNWN[2]; + float N = (float)bias_c_RUBNWN[2]; for (int i = 0; i < num_output; i++) { - __fp16 h_cont = (__fp16)hidden_state[i]; + float h_cont = hidden_state[i]; N += weight_hc_int8_RUN[0] * descale_hc_N * h_cont; weight_hc_int8_RUN += 1; } - N = bias_c_RUBNWN[3] + (__fp16)R32 * N; + N = (float)bias_c_RUBNWN[3] + R * N; for (int i = 0; i < size; i++) { - __fp16 xi = x[i]; + float xi = (float)x[i]; N += weight_xc_int8_RUN[0] * descale_xc_N * xi; @@ -1536,12 +1542,12 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c } // tanh(N) - float N32 = tanhf((float)N); + N = tanhf(N); float* gates_data = gates.row(q / 4 + q % 4); - gates_data[0] = U32; - gates_data[1] = N32; + gates_data[0] = U; + gates_data[1] = N; } // h_t := (1 - update) .* new + update .* h_{t-1} @@ -2008,206 +2014,6 @@ int GRU_arm::forward_fp16s(const std::vector& bottom_blobs, 
std::vector<Mat>
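
Reading aid (not part of the patch): the hunks above rewrite the int8 GRU kernels so that the fp32-storage path no longer uses fp16 arithmetic, but the per-output gate math is unchanged. The sketch below is a minimal scalar reference for that math, assuming a hypothetical helper name (gru_cell_int8_ref) and a hypothetical [R, U, N] layout for the descale arrays; only the arithmetic mirrors the scalar tail loops visible in the diff.

#include <math.h>

// Scalar reference for one GRU output element with int8 weights.
// weight_xc / weight_hc hold the quantized weights for gates R and U
// interleaved, followed by the N-gate weights; descale_xc / descale_hc
// hold the per-gate dequantization factors in the order [R, U, N].
// bias_RUBN follows the patch's bias_c_RUBNWN layout: R, U, N-bias, N-"W" bias.
static void gru_cell_int8_ref(const float* x, int size,
                              const float* hidden, int num_output,
                              const signed char* weight_xc, const float* descale_xc,
                              const signed char* weight_hc, const float* descale_hc,
                              const float* bias_RUBN,
                              float* U_out, float* N_out)
{
    float R = bias_RUBN[0];
    float U = bias_RUBN[1];

    // reset / update gates: input and recurrent terms, dequantized on the fly
    for (int i = 0; i < size; i++)
    {
        R += weight_xc[0] * descale_xc[0] * x[i];
        U += weight_xc[1] * descale_xc[1] * x[i];
        weight_xc += 2;
    }
    for (int i = 0; i < num_output; i++)
    {
        R += weight_hc[0] * descale_hc[0] * hidden[i];
        U += weight_hc[1] * descale_hc[1] * hidden[i];
        weight_hc += 2;
    }
    R = 1.f / (1.f + expf(-R)); // sigmoid(R)
    U = 1.f / (1.f + expf(-U)); // sigmoid(U)

    // new gate: the recurrent part is scaled by R before adding the input part
    float N = bias_RUBN[2];
    for (int i = 0; i < num_output; i++)
    {
        N += weight_hc[0] * descale_hc[2] * hidden[i];
        weight_hc += 1;
    }
    N = bias_RUBN[3] + R * N;
    for (int i = 0; i < size; i++)
    {
        N += weight_xc[0] * descale_xc[2] * x[i];
        weight_xc += 1;
    }
    N = tanhf(N);

    *U_out = U;
    *N_out = N;
}

The caller then combines the two stored gate values exactly as the patch comment states, h_t := (1 - update) .* new + update .* h_{t-1}, i.e. h[q] = (1.f - U) * N + U * h[q].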