From 29b493036fe702c0a9d2448f1eb2bf6c0db116cd Mon Sep 17 00:00:00 2001
From: nihuini
Date: Tue, 23 Apr 2024 17:45:54 +0800
Subject: [PATCH] code clean

---
 src/layer/arm/rnn_arm.cpp         |  14 +-
 src/layer/arm/rnn_arm.h           |   2 -
 src/layer/arm/rnn_arm_asimdhp.cpp | 649 +++++++++++------------------
 3 files changed, 229 insertions(+), 436 deletions(-)

diff --git a/src/layer/arm/rnn_arm.cpp b/src/layer/arm/rnn_arm.cpp
index db92e87b885..3c5ecda64c7 100644
--- a/src/layer/arm/rnn_arm.cpp
+++ b/src/layer/arm/rnn_arm.cpp
@@ -656,12 +651,7 @@ int RNN_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) c
 
 #if NCNN_ARM82
     if (support_fp16_storage && opt.use_fp16_storage && elembits == 16)
-    {
-        if (opt.use_fp16_arithmetic)
-            return forward_fp16sa(bottom_blob, top_blob, opt);
-        else
-            return forward_fp16s(bottom_blob, top_blob, opt);
-    }
+        return forward_fp16s(bottom_blob, top_blob, opt);
 #endif
 
 #if NCNN_BF16
@@ -766,12 +761,7 @@ int RNN_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top
 
 #if NCNN_ARM82
     if (support_fp16_storage && opt.use_fp16_storage && elembits == 16)
-    {
-        if (opt.use_fp16_arithmetic)
-            return forward_fp16sa(bottom_blobs, top_blobs, opt);
-        else
-            return forward_fp16s(bottom_blobs, top_blobs, opt);
-    }
+        return forward_fp16s(bottom_blobs, top_blobs, opt);
 #endif
 
 #if NCNN_BF16
diff --git a/src/layer/arm/rnn_arm.h b/src/layer/arm/rnn_arm.h
index 97583e98ca6..38de2577d87 100644
--- a/src/layer/arm/rnn_arm.h
+++ b/src/layer/arm/rnn_arm.h
@@ -36,8 +36,6 @@ class RNN_arm : public RNN
     int create_pipeline_fp16s(const Option& opt);
     int forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const;
     int forward_fp16s(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const;
-    int forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const;
-    int forward_fp16sa(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const;
 #endif
 #if NCNN_BF16
     int create_pipeline_bf16s(const Option& opt);
diff --git a/src/layer/arm/rnn_arm_asimdhp.cpp b/src/layer/arm/rnn_arm_asimdhp.cpp
index aef51ec608e..fbc0a31b39b 100644
--- a/src/layer/arm/rnn_arm_asimdhp.cpp
+++ b/src/layer/arm/rnn_arm_asimdhp.cpp
@@ -23,148 +23,6 @@ namespace ncnn {
 
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 
-static int rnn_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt)
-{
-    int size = bottom_blob.w;
-    int T = bottom_blob.h;
-
-    int num_output = top_blob.w;
-
-    // num_output
-    Mat gates(num_output, 4u, opt.workspace_allocator);
-    if (gates.empty())
-        return -100;
-
-    // unroll
-    for (int t = 0; t < T; t++)
-    {
-        int ti = reverse ? T - 1 - t : t;
-
-        const __fp16* x = bottom_blob.row<const __fp16>(ti);
-
-        int nn_num_output = num_output >> 2;
-        int remain_num_output_start = nn_num_output << 2;
-        #pragma omp parallel for num_threads(opt.num_threads)
-        for (int qq = 0; qq < nn_num_output; qq++)
-        {
-            int q = qq * 4;
-
-            const __fp16* weight_xc_ptr = weight_xc.row<const __fp16>(q / 4);
-            const __fp16* weight_hc_ptr = weight_hc.row<const __fp16>(q / 4);
-
-            float32x4_t _rnn_H = vcvt_f32_f16(vld1_f16((const __fp16*)bias_c + q));
-            float32x4_t _sum1 = vdupq_n_f32(0.f);
-            float32x4_t _sum2 = vdupq_n_f32(0.f);
-            float32x4_t _sum3 = vdupq_n_f32(0.f);
-
-            int i = 0;
-            for (; i + 3 < size; i += 4)
-            {
-                float32x4_t _x = vcvt_f32_f16(vld1_f16(x + i));
-                float32x4_t _weight_xc = vcvt_f32_f16(vld1_f16(weight_xc_ptr));
-                float32x4_t _weight_xc_1 = vcvt_f32_f16(vld1_f16(weight_xc_ptr + 4));
-                float32x4_t _weight_xc_2 = vcvt_f32_f16(vld1_f16(weight_xc_ptr + 8));
-                float32x4_t _weight_xc_3 = vcvt_f32_f16(vld1_f16(weight_xc_ptr + 12));
-                _rnn_H = vfmaq_laneq_f32(_rnn_H, _weight_xc, _x, 0);
-                _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_1, _x, 1);
-                _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_2, _x, 2);
-                _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_3, _x, 3);
-
-                weight_xc_ptr += 16;
-            }
-            for (; i < size; i++)
-            {
-                float32x4_t _x = vcvt_f32_f16(vdup_n_f16(x[i]));
-                float32x4_t _weight_xc = vcvt_f32_f16(vld1_f16(weight_xc_ptr));
-                _rnn_H = vfmaq_f32(_rnn_H, _weight_xc, _x);
-
-                weight_xc_ptr += 4;
-            }
-
-            i = 0;
-            for (; i + 3 < num_output; i += 4)
-            {
-                float32x4_t _hidden_state = vld1q_f32((const float*)hidden_state + i);
-                float32x4_t _weight_hc = vcvt_f32_f16(vld1_f16(weight_hc_ptr));
-                float32x4_t _weight_hc_1 = vcvt_f32_f16(vld1_f16(weight_hc_ptr + 4));
-                float32x4_t _weight_hc_2 = vcvt_f32_f16(vld1_f16(weight_hc_ptr + 8));
-                float32x4_t _weight_hc_3 = vcvt_f32_f16(vld1_f16(weight_hc_ptr + 12));
-                _rnn_H = vfmaq_laneq_f32(_rnn_H, _weight_hc, _hidden_state, 0);
-                _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_1, _hidden_state, 1);
-                _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_2, _hidden_state, 2);
-                _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_3, _hidden_state, 3);
-
-                weight_hc_ptr += 16;
-            }
-            for (; i < num_output; i++)
-            {
-                float32x4_t _hidden_state = vdupq_n_f32(hidden_state[i]);
-                float32x4_t _weight_hc = vcvt_f32_f16(vld1_f16(weight_hc_ptr));
-                _rnn_H = vfmaq_f32(_rnn_H, _weight_hc, _hidden_state);
-
-                weight_hc_ptr += 4;
-            }
-
-            _rnn_H = vaddq_f32(_rnn_H, _sum1);
-            _sum2 = vaddq_f32(_sum2, _sum3);
-            _rnn_H = vaddq_f32(_rnn_H, _sum2);
-
-            _rnn_H = tanh_ps(_rnn_H);
-
-            vst1q_f32((float*)gates + q, _rnn_H);
-        }
-        #pragma omp parallel for num_threads(opt.num_threads)
-        for (int q = remain_num_output_start; q < num_output; q++)
-        {
-            const __fp16* weight_xc_ptr = weight_xc.row<const __fp16>(q / 4 + q % 4);
-            const __fp16* weight_hc_ptr = weight_hc.row<const __fp16>(q / 4 + q % 4);
-
-            float H = (float)(((const __fp16*)bias_c)[q]);
-
-            for (int i = 0; i < size; i++)
-            {
-                H += (float)weight_xc_ptr[i] * (float)x[i];
-            }
-
-            for (int i = 0; i < num_output; i++)
-            {
-                H += (float)weight_hc_ptr[i] * hidden_state[i];
-            }
-
-            H = tanhf(H);
-
-            gates[q] = H;
-        }
-
-        __fp16* output_data = top_blob.row<__fp16>(ti);
-
-        float* hidden_ptr = hidden_state;
-
-        nn_num_output = num_output >> 2;
-        remain_num_output_start = nn_num_output << 2;
-        #pragma omp parallel for num_threads(opt.num_threads)
-        for (int qq = 0; qq < nn_num_output; qq++)
-        {
-            int q = qq * 4;
-
-            float32x4_t _rnn_H = vld1q_f32((float*)gates + q);
-
-            vst1q_f32(hidden_ptr + q, _rnn_H);
-            vst1_f16(output_data + q, vcvt_f16_f32(_rnn_H));
-        }
-        #pragma omp parallel for num_threads(opt.num_threads)
-        for (int q = remain_num_output_start; q < num_output; q++)
-        {
-            float H = gates[q];
-
-            hidden_ptr[q] = H;
-            output_data[q] = (__fp16)H;
-        }
-    }
-
-    return 0;
-}
-
 static int rnn_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt)
 {
     int size = bottom_blob.w;
@@ -380,9 +238,11 @@ static int rnn_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const
     return 0;
 }
 
-#if NCNN_INT8
-static int rnn_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const Mat& weight_xc_int8_descales, const Mat& bias_c, const Mat& weight_hc_int8, const Mat& weight_hc_int8_descales, Mat& hidden_state, const Option& opt)
+static int rnn_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt)
 {
+    if (opt.use_fp16_arithmetic)
+        return rnn_fp16sa(bottom_blob, top_blob, reverse, weight_xc, bias_c, weight_hc, hidden_state, opt);
+
     int size = bottom_blob.w;
     int T = bottom_blob.h;
 
@@ -407,14 +267,8 @@ static int rnn_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co
         {
             int q = qq * 4;
 
-            const signed char* weight_xc_int8_ptr = weight_xc_int8.row<const signed char>(q / 4);
-            const signed char* weight_hc_int8_ptr = weight_hc_int8.row<const signed char>(q / 4);
-
-            const float* weight_xc_int8_descales_ptr = weight_xc_int8_descales.row(q / 4);
-            const float* weight_hc_int8_descales_ptr = weight_hc_int8_descales.row(q / 4);
-
-            float32x4_t _descale_xc = vld1q_f32(weight_xc_int8_descales_ptr);
-            float32x4_t _descale_hc = vld1q_f32(weight_hc_int8_descales_ptr);
+            const __fp16* weight_xc_ptr = weight_xc.row<const __fp16>(q / 4);
+            const __fp16* weight_hc_ptr = weight_hc.row<const __fp16>(q / 4);
 
             float32x4_t _rnn_H = vcvt_f32_f16(vld1_f16((const __fp16*)bias_c + q));
             float32x4_t _sum1 = vdupq_n_f32(0.f);
@@ -425,58 +279,48 @@ static int rnn_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co
             for (; i + 3 < size; i += 4)
             {
                 float32x4_t _x = vcvt_f32_f16(vld1_f16(x + i));
-
-                int8x16_t _weight_xc = vld1q_s8(weight_xc_int8_ptr);
-                int16x8_t _weight_xc_01 = vmovl_s8(vget_low_s8(_weight_xc));
-                int16x8_t _weight_xc_23 = vmovl_s8(vget_high_s8(_weight_xc));
-                float32x4_t _weight_xc_0 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_01))), _descale_xc);
-                float32x4_t _weight_xc_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_01))), _descale_xc);
-                float32x4_t _weight_xc_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_23))), _descale_xc);
-                float32x4_t _weight_xc_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_23))), _descale_xc);
-
-                _rnn_H = vfmaq_laneq_f32(_rnn_H, _weight_xc_0, _x, 0);
+                float32x4_t _weight_xc = vcvt_f32_f16(vld1_f16(weight_xc_ptr));
+                float32x4_t _weight_xc_1 = vcvt_f32_f16(vld1_f16(weight_xc_ptr + 4));
+                float32x4_t _weight_xc_2 = vcvt_f32_f16(vld1_f16(weight_xc_ptr + 8));
+                float32x4_t _weight_xc_3 = vcvt_f32_f16(vld1_f16(weight_xc_ptr + 12));
+                _rnn_H = vfmaq_laneq_f32(_rnn_H, _weight_xc, _x, 0);
                 _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_1, _x, 1);
                 _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_2, _x, 2);
                 _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_3, _x, 3);
 
-                weight_xc_int8_ptr += 16;
+                weight_xc_ptr += 16;
             }
             for (; i < size; i++)
             {
                 float32x4_t _x = vcvt_f32_f16(vdup_n_f16(x[i]));
-                float32x4_t _weight_xc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_ptr))))), _descale_xc);
+                float32x4_t _weight_xc = vcvt_f32_f16(vld1_f16(weight_xc_ptr));
                 _rnn_H = vfmaq_f32(_rnn_H, _weight_xc, _x);
 
-                weight_xc_int8_ptr += 4;
+                weight_xc_ptr += 4;
             }
 
             i = 0;
             for (; i + 3 < num_output; i += 4)
             {
                 float32x4_t _hidden_state = vld1q_f32((const float*)hidden_state + i);
-
-                int8x16_t _weight_hc = vld1q_s8(weight_hc_int8_ptr);
-                int16x8_t _weight_hc_01 = vmovl_s8(vget_low_s8(_weight_hc));
-                int16x8_t _weight_hc_23 = vmovl_s8(vget_high_s8(_weight_hc));
-                float32x4_t _weight_hc_0 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_01))), _descale_hc);
-                float32x4_t _weight_hc_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_01))), _descale_hc);
-                float32x4_t _weight_hc_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_23))), _descale_hc);
-                float32x4_t _weight_hc_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_23))), _descale_hc);
-
-                _rnn_H = vfmaq_laneq_f32(_rnn_H, _weight_hc_0, _hidden_state, 0);
+                float32x4_t _weight_hc = vcvt_f32_f16(vld1_f16(weight_hc_ptr));
+                float32x4_t _weight_hc_1 = vcvt_f32_f16(vld1_f16(weight_hc_ptr + 4));
+                float32x4_t _weight_hc_2 = vcvt_f32_f16(vld1_f16(weight_hc_ptr + 8));
+                float32x4_t _weight_hc_3 = vcvt_f32_f16(vld1_f16(weight_hc_ptr + 12));
+                _rnn_H = vfmaq_laneq_f32(_rnn_H, _weight_hc, _hidden_state, 0);
                 _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_1, _hidden_state, 1);
                 _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_2, _hidden_state, 2);
                 _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_3, _hidden_state, 3);
 
-                weight_hc_int8_ptr += 16;
+                weight_hc_ptr += 16;
             }
             for (; i < num_output; i++)
             {
                 float32x4_t _hidden_state = vdupq_n_f32(hidden_state[i]);
-                float32x4_t _weight_hc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_ptr))))), _descale_hc);
+                float32x4_t _weight_hc = vcvt_f32_f16(vld1_f16(weight_hc_ptr));
                 _rnn_H = vfmaq_f32(_rnn_H, _weight_hc, _hidden_state);
 
-                weight_hc_int8_ptr += 4;
+                weight_hc_ptr += 4;
             }
 
             _rnn_H = vaddq_f32(_rnn_H, _sum1);
@@ -490,24 +334,19 @@ static int rnn_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co
         #pragma omp parallel for num_threads(opt.num_threads)
         for (int q = remain_num_output_start; q < num_output; q++)
         {
-            const signed char* weight_xc_int8_ptr = weight_xc_int8.row<const signed char>(q / 4 + q % 4);
-            const signed char* weight_hc_int8_ptr = weight_hc_int8.row<const signed char>(q / 4 + q % 4);
-            const float* weight_xc_int8_descales_ptr = weight_xc_int8_descales.row(q / 4 + q % 4);
-            const float* weight_hc_int8_descales_ptr = weight_hc_int8_descales.row(q / 4 + q % 4);
-
-            const float descale_xc = weight_xc_int8_descales_ptr[0];
-            const float descale_hc = weight_hc_int8_descales_ptr[0];
+            const __fp16* weight_xc_ptr = weight_xc.row<const __fp16>(q / 4 + q % 4);
+            const __fp16* weight_hc_ptr = weight_hc.row<const __fp16>(q / 4 + q % 4);
 
             float H = (float)(((const __fp16*)bias_c)[q]);
 
             for (int i = 0; i < size; i++)
             {
-                H += weight_xc_int8_ptr[i] * descale_xc * (float)x[i];
+                H += (float)weight_xc_ptr[i] * (float)x[i];
             }
 
             for (int i = 0; i < num_output; i++)
             {
-                H += weight_hc_int8_ptr[i] * descale_hc * hidden_state[i];
+                H += (float)weight_hc_ptr[i] * hidden_state[i];
             }
 
             H = tanhf(H);
@@ -543,7 +382,7 @@ static int rnn_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co
 
     return 0;
 }
-
+#if NCNN_INT8
 static int rnn_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const Mat& weight_xc_int8_descales, const Mat& bias_c, const Mat& weight_hc_int8, const Mat& weight_hc_int8_descales, Mat& hidden_state, const Option& opt)
 {
     int size = bottom_blob.w;
@@ -698,9 +537,176 @@ static int rnn_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
             }
             for (; i < size; i++)
             {
-               float16x4_t _x = vdup_n_f16(x[i]);
-               float16x4_t _weight_xc = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_ptr)))), _descale_xc);
-               _rnn_H = vfma_f16(_rnn_H, _weight_xc, _x);
+                float16x4_t _x = vdup_n_f16(x[i]);
+                float16x4_t _weight_xc = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_ptr)))), _descale_xc);
+                _rnn_H = vfma_f16(_rnn_H, _weight_xc, _x);
+
+                weight_xc_int8_ptr += 4;
+            }
+
+            i = 0;
+            for (; i + 3 < num_output; i += 4)
+            {
+                float16x4_t _hidden_state = vcvt_f16_f32(vld1q_f32((const float*)hidden_state + i));
+
+                int8x16_t _weight_hc = vld1q_s8(weight_hc_int8_ptr);
+                float16x8_t _weight_hc_01 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc))), _descale_hc_2);
+                float16x8_t _weight_hc_23 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc))), _descale_hc_2);
+                float16x4_t _weight_hc_0 = vget_low_f16(_weight_hc_01);
+                float16x4_t _weight_hc_1 = vget_high_f16(_weight_hc_01);
+                float16x4_t _weight_hc_2 = vget_low_f16(_weight_hc_23);
+                float16x4_t _weight_hc_3 = vget_high_f16(_weight_hc_23);
+
+                _rnn_H = vfma_lane_f16(_rnn_H, _weight_hc_0, _hidden_state, 0);
+                _sum1 = vfma_lane_f16(_sum1, _weight_hc_1, _hidden_state, 1);
+                _sum2 = vfma_lane_f16(_sum2, _weight_hc_2, _hidden_state, 2);
+                _sum3 = vfma_lane_f16(_sum3, _weight_hc_3, _hidden_state, 3);
+
+                weight_hc_int8_ptr += 16;
+            }
+            for (; i < num_output; i++)
+            {
+                float16x4_t _hidden_state = vdup_n_f16((__fp16)hidden_state[i]);
+                float16x4_t _weight_hc = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_ptr)))), _descale_hc);
+                _rnn_H = vfma_f16(_rnn_H, _weight_hc, _hidden_state);
+
+                weight_hc_int8_ptr += 4;
+            }
+
+            _rnn_H = vadd_f16(_rnn_H, _sum1);
+            _sum2 = vadd_f16(_sum2, _sum3);
+            _rnn_H = vadd_f16(_rnn_H, _sum2);
+
+            float32x4_t _H32 = tanh_ps(vcvt_f32_f16(_rnn_H));
+
+            vst1q_f32((float*)gates + q, _H32);
+        }
+        remain_num_output_start += nn_num_output << 2;
+        #pragma omp parallel for num_threads(opt.num_threads)
+        for (int q = remain_num_output_start; q < num_output; q++)
+        {
+            const signed char* weight_xc_int8_ptr = weight_xc_int8.row<const signed char>(q / 8 + (q % 8) / 4 + q % 4);
+            const signed char* weight_hc_int8_ptr = weight_hc_int8.row<const signed char>(q / 8 + (q % 8) / 4 + q % 4);
+            const __fp16* weight_xc_int8_descales_ptr = weight_xc_int8_descales.row<const __fp16>(q / 8 + (q % 8) / 4 + q % 4);
+            const __fp16* weight_hc_int8_descales_ptr = weight_hc_int8_descales.row<const __fp16>(q / 8 + (q % 8) / 4 + q % 4);
+
+            const __fp16 descale_xc = weight_xc_int8_descales_ptr[0];
+            const __fp16 descale_hc = weight_hc_int8_descales_ptr[0];
+
+            __fp16 H = ((const __fp16*)bias_c)[q];
+
+            for (int i = 0; i < size; i++)
+            {
+                H += (__fp16)weight_xc_int8_ptr[i] * descale_xc * x[i];
+            }
+
+            for (int i = 0; i < num_output; i++)
+            {
+                H += (__fp16)weight_hc_int8_ptr[i] * descale_hc * (__fp16)hidden_state[i];
+            }
+
+            float H32 = tanhf((float)H);
+
+            gates[q] = H32;
+        }
+
+        __fp16* output_data = top_blob.row<__fp16>(ti);
+
+        float* hidden_ptr = hidden_state;
+
+        nn_num_output = num_output >> 2;
+        remain_num_output_start = nn_num_output << 2;
+        #pragma omp parallel for num_threads(opt.num_threads)
+        for (int qq = 0; qq < nn_num_output; qq++)
+        {
+            int q = qq * 4;
+
+            float32x4_t _rnn_H = vld1q_f32((float*)gates + q);
+
+            vst1q_f32(hidden_ptr + q, _rnn_H);
+            vst1_f16(output_data + q, vcvt_f16_f32(_rnn_H));
+        }
+        #pragma omp parallel for num_threads(opt.num_threads)
+        for (int q = remain_num_output_start; q < num_output; q++)
+        {
+            float H = gates[q];
+
+            hidden_ptr[q] = H;
+            output_data[q] = (__fp16)H;
+        }
+    }
+
+    return 0;
+}
+
+static int rnn_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const Mat& weight_xc_int8_descales, const Mat& bias_c, const Mat& weight_hc_int8, const Mat& weight_hc_int8_descales, Mat& hidden_state, const Option& opt)
+{
+    if (opt.use_fp16_arithmetic)
+        return rnn_fp16sa_int8(bottom_blob, top_blob, reverse, weight_xc_int8, weight_xc_int8_descales, bias_c, weight_hc_int8, weight_hc_int8_descales, hidden_state, opt);
+
+    int size = bottom_blob.w;
+    int T = bottom_blob.h;
+
+    int num_output = top_blob.w;
+
+    // num_output
+    Mat gates(num_output, 4u, opt.workspace_allocator);
+    if (gates.empty())
+        return -100;
+
+    // unroll
+    for (int t = 0; t < T; t++)
+    {
+        int ti = reverse ? T - 1 - t : t;
+
+        const __fp16* x = bottom_blob.row<const __fp16>(ti);
+
+        int nn_num_output = num_output >> 2;
+        int remain_num_output_start = nn_num_output << 2;
+        #pragma omp parallel for num_threads(opt.num_threads)
+        for (int qq = 0; qq < nn_num_output; qq++)
+        {
+            int q = qq * 4;
+
+            const signed char* weight_xc_int8_ptr = weight_xc_int8.row<const signed char>(q / 4);
+            const signed char* weight_hc_int8_ptr = weight_hc_int8.row<const signed char>(q / 4);
+
+            const float* weight_xc_int8_descales_ptr = weight_xc_int8_descales.row(q / 4);
+            const float* weight_hc_int8_descales_ptr = weight_hc_int8_descales.row(q / 4);
+
+            float32x4_t _descale_xc = vld1q_f32(weight_xc_int8_descales_ptr);
+            float32x4_t _descale_hc = vld1q_f32(weight_hc_int8_descales_ptr);
+
+            float32x4_t _rnn_H = vcvt_f32_f16(vld1_f16((const __fp16*)bias_c + q));
+            float32x4_t _sum1 = vdupq_n_f32(0.f);
+            float32x4_t _sum2 = vdupq_n_f32(0.f);
+            float32x4_t _sum3 = vdupq_n_f32(0.f);
+
+            int i = 0;
+            for (; i + 3 < size; i += 4)
+            {
+                float32x4_t _x = vcvt_f32_f16(vld1_f16(x + i));
+
+                int8x16_t _weight_xc = vld1q_s8(weight_xc_int8_ptr);
+                int16x8_t _weight_xc_01 = vmovl_s8(vget_low_s8(_weight_xc));
+                int16x8_t _weight_xc_23 = vmovl_s8(vget_high_s8(_weight_xc));
+                float32x4_t _weight_xc_0 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_01))), _descale_xc);
+                float32x4_t _weight_xc_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_01))), _descale_xc);
+                float32x4_t _weight_xc_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_xc_23))), _descale_xc);
+                float32x4_t _weight_xc_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_xc_23))), _descale_xc);
+
+                _rnn_H = vfmaq_laneq_f32(_rnn_H, _weight_xc_0, _x, 0);
+                _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_1, _x, 1);
+                _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_2, _x, 2);
+                _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_3, _x, 3);
+
+                weight_xc_int8_ptr += 16;
+            }
+            for (; i < size; i++)
+            {
+                float32x4_t _x = vcvt_f32_f16(vdup_n_f16(x[i]));
+                float32x4_t _weight_xc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_ptr))))), _descale_xc);
+                _rnn_H = vfmaq_f32(_rnn_H, _weight_xc, _x);
 
                 weight_xc_int8_ptr += 4;
             }
@@ -708,67 +714,66 @@ static int rnn_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
             i = 0;
             for (; i + 3 < num_output; i += 4)
             {
-                float16x4_t _hidden_state = vcvt_f16_f32(vld1q_f32((const float*)hidden_state + i));
+                float32x4_t _hidden_state = vld1q_f32((const float*)hidden_state + i);
 
                 int8x16_t _weight_hc = vld1q_s8(weight_hc_int8_ptr);
-                float16x8_t _weight_hc_01 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc))), _descale_hc_2);
-                float16x8_t _weight_hc_23 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc))), _descale_hc_2);
-                float16x4_t _weight_hc_0 = vget_low_f16(_weight_hc_01);
-                float16x4_t _weight_hc_1 = vget_high_f16(_weight_hc_01);
-                float16x4_t _weight_hc_2 = vget_low_f16(_weight_hc_23);
-                float16x4_t _weight_hc_3 = vget_high_f16(_weight_hc_23);
+                int16x8_t _weight_hc_01 = vmovl_s8(vget_low_s8(_weight_hc));
+                int16x8_t _weight_hc_23 = vmovl_s8(vget_high_s8(_weight_hc));
+                float32x4_t _weight_hc_0 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_01))), _descale_hc);
+                float32x4_t _weight_hc_1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_01))), _descale_hc);
+                float32x4_t _weight_hc_2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(_weight_hc_23))), _descale_hc);
+                float32x4_t _weight_hc_3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(_weight_hc_23))), _descale_hc);
 
-                _rnn_H = vfma_lane_f16(_rnn_H, _weight_hc_0, _hidden_state, 0);
-                _sum1 = vfma_lane_f16(_sum1, _weight_hc_1, _hidden_state, 1);
-                _sum2 = vfma_lane_f16(_sum2, _weight_hc_2, _hidden_state, 2);
-                _sum3 = vfma_lane_f16(_sum3, _weight_hc_3, _hidden_state, 3);
+                _rnn_H = vfmaq_laneq_f32(_rnn_H, _weight_hc_0, _hidden_state, 0);
+                _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_1, _hidden_state, 1);
+                _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_2, _hidden_state, 2);
+                _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_3, _hidden_state, 3);
 
                 weight_hc_int8_ptr += 16;
             }
             for (; i < num_output; i++)
             {
-                float16x4_t _hidden_state = vdup_n_f16((__fp16)hidden_state[i]);
-                float16x4_t _weight_hc = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_ptr)))), _descale_hc);
-                _rnn_H = vfma_f16(_rnn_H, _weight_hc, _hidden_state);
+                float32x4_t _hidden_state = vdupq_n_f32(hidden_state[i]);
+                float32x4_t _weight_hc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_ptr))))), _descale_hc);
+                _rnn_H = vfmaq_f32(_rnn_H, _weight_hc, _hidden_state);
 
                 weight_hc_int8_ptr += 4;
             }
 
-            _rnn_H = vadd_f16(_rnn_H, _sum1);
-            _sum2 = vadd_f16(_sum2, _sum3);
-            _rnn_H = vadd_f16(_rnn_H, _sum2);
+            _rnn_H = vaddq_f32(_rnn_H, _sum1);
+            _sum2 = vaddq_f32(_sum2, _sum3);
+            _rnn_H = vaddq_f32(_rnn_H, _sum2);
 
-            float32x4_t _H32 = tanh_ps(vcvt_f32_f16(_rnn_H));
+            _rnn_H = tanh_ps(_rnn_H);
 
-            vst1q_f32((float*)gates + q, _H32);
+            vst1q_f32((float*)gates + q, _rnn_H);
         }
-        remain_num_output_start += nn_num_output << 2;
         #pragma omp parallel for num_threads(opt.num_threads)
         for (int q = remain_num_output_start; q < num_output; q++)
         {
-            const signed char* weight_xc_int8_ptr = weight_xc_int8.row<const signed char>(q / 8 + (q % 8) / 4 + q % 4);
-            const signed char* weight_hc_int8_ptr = weight_hc_int8.row<const signed char>(q / 8 + (q % 8) / 4 + q % 4);
-            const __fp16* weight_xc_int8_descales_ptr = weight_xc_int8_descales.row<const __fp16>(q / 8 + (q % 8) / 4 + q % 4);
-            const __fp16* weight_hc_int8_descales_ptr = weight_hc_int8_descales.row<const __fp16>(q / 8 + (q % 8) / 4 + q % 4);
+            const signed char* weight_xc_int8_ptr = weight_xc_int8.row<const signed char>(q / 4 + q % 4);
+            const signed char* weight_hc_int8_ptr = weight_hc_int8.row<const signed char>(q / 4 + q % 4);
+            const float* weight_xc_int8_descales_ptr = weight_xc_int8_descales.row(q / 4 + q % 4);
+            const float* weight_hc_int8_descales_ptr = weight_hc_int8_descales.row(q / 4 + q % 4);
 
-            const __fp16 descale_xc = weight_xc_int8_descales_ptr[0];
-            const __fp16 descale_hc = weight_hc_int8_descales_ptr[0];
+            const float descale_xc = weight_xc_int8_descales_ptr[0];
+            const float descale_hc = weight_hc_int8_descales_ptr[0];
 
-            __fp16 H = ((const __fp16*)bias_c)[q];
+            float H = (float)(((const __fp16*)bias_c)[q]);
 
             for (int i = 0; i < size; i++)
             {
-                H += (__fp16)weight_xc_int8_ptr[i] * descale_xc * x[i];
+                H += weight_xc_int8_ptr[i] * descale_xc * (float)x[i];
             }
 
             for (int i = 0; i < num_output; i++)
             {
-                H += (__fp16)weight_hc_int8_ptr[i] * descale_hc * (__fp16)hidden_state[i];
+                H += weight_hc_int8_ptr[i] * descale_hc * hidden_state[i];
             }
 
-            float H32 = tanhf((float)H);
+            H = tanhf(H);
 
-            gates[q] = H32;
+            gates[q] = H;
         }
 
         __fp16* output_data = top_blob.row<__fp16>(ti);
@@ -1357,206 +1362,6 @@ int RNN_arm::forward_fp16s(const std::vector<Mat>& bottom_blobs, std::vector
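
All of the kernels this patch keeps share one inner-product shape: fp16 weights are widened (to fp32 on the fp16s path, staying in fp16 on the fp16sa path), four NEON accumulators run in parallel so the fused multiply-adds do not serialize on a single register, and the partial sums are reduced once before tanh is applied. The standalone sketch below distills that pattern under the same __ARM_FEATURE_FP16_VECTOR_ARITHMETIC guard as the file; the name rnn_dot4 and its simplified signature are illustrative, not part of the patch.

#include <arm_neon.h>

// Sketch: 4-lane dot product with fp16 storage and fp32 arithmetic.
// w holds 16 fp16 weights per 4 inputs (4 output lanes x 4 inputs);
// acc starts from the bias lanes and the caller applies tanh afterwards.
static float32x4_t rnn_dot4(const __fp16* w, const __fp16* x, int size, float32x4_t acc)
{
    float32x4_t sum1 = vdupq_n_f32(0.f);
    float32x4_t sum2 = vdupq_n_f32(0.f);
    float32x4_t sum3 = vdupq_n_f32(0.f);
    int i = 0;
    for (; i + 3 < size; i += 4)
    {
        float32x4_t vx = vcvt_f32_f16(vld1_f16(x + i));
        // four independent accumulators hide the FMA latency chain
        acc = vfmaq_laneq_f32(acc, vcvt_f32_f16(vld1_f16(w)), vx, 0);
        sum1 = vfmaq_laneq_f32(sum1, vcvt_f32_f16(vld1_f16(w + 4)), vx, 1);
        sum2 = vfmaq_laneq_f32(sum2, vcvt_f32_f16(vld1_f16(w + 8)), vx, 2);
        sum3 = vfmaq_laneq_f32(sum3, vcvt_f32_f16(vld1_f16(w + 12)), vx, 3);
        w += 16;
    }
    for (; i < size; i++)
    {
        // scalar tail: broadcast one input against 4 weights
        float32x4_t vx = vcvt_f32_f16(vdup_n_f16(x[i]));
        acc = vfmaq_f32(acc, vcvt_f32_f16(vld1_f16(w)), vx);
        w += 4;
    }
    acc = vaddq_f32(acc, sum1);
    return vaddq_f32(acc, vaddq_f32(sum2, sum3));
}

Chaining vfmaq_laneq_f32 onto four separate registers lets an out-of-order core overlap the multiply-accumulates instead of stalling on one accumulator, which is why the kernels carry _sum1.._sum3 alongside _rnn_H.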
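The int8 kernels keep the weights quantized in memory and dequantize inside the registers. On the fp32-compute path that is the widening chain s8 -> s16 -> s32 -> f32 followed by a multiply with the per-output-channel descale vector, as in this sketch (dequant_fma4 is an illustrative name, not a function in the patch):

#include <arm_neon.h>

// Sketch: dequantize 4 int8 weights and fold them into the accumulator.
static float32x4_t dequant_fma4(float32x4_t acc, const signed char* w,
                                float32x4_t descale, float32x4_t x)
{
    int8x8_t w8 = vld1_s8(w);                                // load 8 int8, low 4 used
    int16x8_t w16 = vmovl_s8(w8);                            // s8  -> s16
    int32x4_t w32 = vmovl_s16(vget_low_s16(w16));            // s16 -> s32
    float32x4_t wf = vmulq_f32(vcvtq_f32_s32(w32), descale); // s32 -> f32, apply descale
    return vfmaq_f32(acc, wf, x);                            // multiply-accumulate
}

The fp16sa variant performs the same dequantization at half the width (vcvt_f16_s16 and vmul_f16), which is why its descale factors are stored as __fp16 rather than float.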