From 5cc4072aadf2033148dcad83631308536ba8a2c4 Mon Sep 17 00:00:00 2001
From: nihuini
Date: Mon, 29 Apr 2024 17:23:30 +0800
Subject: [PATCH] opt++

---
 src/layer/arm/rnn_int8.h | 52 ++++++++++++++++------------------------
 1 file changed, 20 insertions(+), 32 deletions(-)

diff --git a/src/layer/arm/rnn_int8.h b/src/layer/arm/rnn_int8.h
index 285f5a72784..bd4bd44878e 100644
--- a/src/layer/arm/rnn_int8.h
+++ b/src/layer/arm/rnn_int8.h
@@ -292,31 +292,25 @@ static void rnn_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
         int32x4_t _sum3 = vdupq_n_s32(0);
         for (; i + 15 < size; i += 16)
         {
-            int32x4_t _xi01 = vreinterpretq_s32_s8(vld1q_s8(x + i));
-            int8x16_t _xi0 = vreinterpretq_s8_s32(vdupq_laneq_s32(_xi01, 0));
-            int8x16_t _xi1 = vreinterpretq_s8_s32(vdupq_laneq_s32(_xi01, 1));
-            int8x16_t _xi2 = vreinterpretq_s8_s32(vdupq_laneq_s32(_xi01, 2));
-            int8x16_t _xi3 = vreinterpretq_s8_s32(vdupq_laneq_s32(_xi01, 3));
+            int8x16_t _xi = vld1q_s8(x + i);
             int8x16_t _w0 = vld1q_s8(kptr);
             int8x16_t _w1 = vld1q_s8(kptr + 16);
             int8x16_t _w2 = vld1q_s8(kptr + 32);
             int8x16_t _w3 = vld1q_s8(kptr + 48);
-            _rnn_Hx0 = vdotq_s32(_rnn_Hx0, _w0, _xi0);
-            _sum1 = vdotq_s32(_sum1, _w1, _xi1);
-            _sum2 = vdotq_s32(_sum2, _w2, _xi2);
-            _sum3 = vdotq_s32(_sum3, _w3, _xi3);
+            _rnn_Hx0 = vdotq_laneq_s32(_rnn_Hx0, _w0, _xi, 0);
+            _sum1 = vdotq_laneq_s32(_sum1, _w1, _xi, 1);
+            _sum2 = vdotq_laneq_s32(_sum2, _w2, _xi, 2);
+            _sum3 = vdotq_laneq_s32(_sum3, _w3, _xi, 3);
             kptr += 64;
         }
         for (; i + 7 < size; i += 8)
         {
-            int32x2_t _xi01 = vreinterpret_s32_s8(vld1_s8(x + i));
-            int8x16_t _xi0 = vreinterpretq_s8_s32(vdupq_lane_s32(_xi01, 0));
-            int8x16_t _xi1 = vreinterpretq_s8_s32(vdupq_lane_s32(_xi01, 1));
+            int8x8_t _xi = vld1_s8(x + i);
             int8x16_t _w0 = vld1q_s8(kptr);
             int8x16_t _w1 = vld1q_s8(kptr + 16);
-            _rnn_Hx0 = vdotq_s32(_rnn_Hx0, _w0, _xi0);
-            _sum1 = vdotq_s32(_sum1, _w1, _xi1);
+            _rnn_Hx0 = vdotq_lane_s32(_rnn_Hx0, _w0, _xi, 0);
+            _sum1 = vdotq_lane_s32(_sum1, _w1, _xi, 1);
             kptr += 32;
         }
@@ -327,9 +321,9 @@ static void rnn_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
         for (; i + 3 < size; i += 4)
         {
 #if __ARM_FEATURE_DOTPROD
-            int8x16_t _xi = vreinterpretq_s8_s32(vdupq_lane_s32(vreinterpret_s32_s8(vld1_s8(x + i)), 0));
+            int8x8_t _xi = vld1_s8(x + i);
             int8x16_t _w = vld1q_s8(kptr);
-            _rnn_Hx0 = vdotq_s32(_rnn_Hx0, _w, _xi);
+            _rnn_Hx0 = vdotq_lane_s32(_rnn_Hx0, _w, _xi, 0);
 #else
             int16x4_t _xi01 = vreinterpret_s16_s8(vld1_s8(x + i));
             int8x8_t _xi0 = vreinterpret_s8_s16(vdup_lane_s16(_xi01, 0));
@@ -372,31 +366,25 @@ static void rnn_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
         _sum3 = vdupq_n_s32(0);
         for (; i + 15 < num_output; i += 16)
         {
-            int32x4_t _h_cont01 = vreinterpretq_s32_s8(vld1q_s8(hs + i));
-            int8x16_t _h_cont0 = vreinterpretq_s8_s32(vdupq_laneq_s32(_h_cont01, 0));
-            int8x16_t _h_cont1 = vreinterpretq_s8_s32(vdupq_laneq_s32(_h_cont01, 1));
-            int8x16_t _h_cont2 = vreinterpretq_s8_s32(vdupq_laneq_s32(_h_cont01, 2));
-            int8x16_t _h_cont3 = vreinterpretq_s8_s32(vdupq_laneq_s32(_h_cont01, 3));
+            int8x16_t _h_cont = vld1q_s8(hs + i);
             int8x16_t _w0 = vld1q_s8(kptr);
             int8x16_t _w1 = vld1q_s8(kptr + 16);
             int8x16_t _w2 = vld1q_s8(kptr + 32);
             int8x16_t _w3 = vld1q_s8(kptr + 48);
-            _rnn_Hh0 = vdotq_s32(_rnn_Hh0, _w0, _h_cont0);
-            _sum1 = vdotq_s32(_sum1, _w1, _h_cont1);
-            _sum2 = vdotq_s32(_sum2, _w2, _h_cont2);
-            _sum3 = vdotq_s32(_sum3, _w3, _h_cont3);
+            _rnn_Hh0 = vdotq_laneq_s32(_rnn_Hh0, _w0, _h_cont, 0);
+            _sum1 = vdotq_laneq_s32(_sum1, _w1, _h_cont, 1);
+            _sum2 = vdotq_laneq_s32(_sum2, _w2, _h_cont, 2);
+            _sum3 = vdotq_laneq_s32(_sum3, _w3, _h_cont, 3);
             kptr += 64;
         }
         for (; i + 7 < num_output; i += 8)
         {
-            int32x2_t _h_cont01 = vreinterpret_s32_s8(vld1_s8(hs + i));
-            int8x16_t _h_cont0 = vreinterpretq_s8_s32(vdupq_lane_s32(_h_cont01, 0));
-            int8x16_t _h_cont1 = vreinterpretq_s8_s32(vdupq_lane_s32(_h_cont01, 1));
+            int8x8_t _h_cont = vld1_s8(hs + i);
             int8x16_t _w0 = vld1q_s8(kptr);
             int8x16_t _w1 = vld1q_s8(kptr + 16);
-            _rnn_Hh0 = vdotq_s32(_rnn_Hh0, _w0, _h_cont0);
-            _sum1 = vdotq_s32(_sum1, _w1, _h_cont1);
+            _rnn_Hh0 = vdotq_lane_s32(_rnn_Hh0, _w0, _h_cont, 0);
+            _sum1 = vdotq_lane_s32(_sum1, _w1, _h_cont, 1);
             kptr += 32;
         }
@@ -407,9 +395,9 @@ static void rnn_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
         for (; i + 3 < num_output; i += 4)
         {
 #if __ARM_FEATURE_DOTPROD
-            int8x16_t _h_cont = vreinterpretq_s8_s32(vdupq_lane_s32(vreinterpret_s32_s8(vld1_s8(hs + i)), 0));
+            int8x8_t _h_cont = vld1_s8(hs + i);
             int8x16_t _w = vld1q_s8(kptr);
-            _rnn_Hh0 = vdotq_s32(_rnn_Hh0, _w, _h_cont);
+            _rnn_Hh0 = vdotq_lane_s32(_rnn_Hh0, _w, _h_cont, 0);
 #else
             int16x4_t _h_cont01 = vreinterpret_s16_s8(vld1_s8(hs + i));
             int8x8_t _h_cont0 = vreinterpret_s8_s16(vdup_lane_s16(_h_cont01, 0));
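
Note on the rewrite: the patch replaces the broadcast-then-dot sequence (vreinterpretq_s32_s8 + vdupq_laneq_s32 + vdotq_s32) with the by-element dot-product intrinsics. vdotq_laneq_s32(acc, w, x, k) accumulates, into each 32-bit lane of acc, the dot product of four int8 weights against the four int8 elements sitting in 32-bit lane k of x, so the lane selection is encoded directly in a single SDOT (by element) instruction and the four vdupq broadcasts per 16-element step disappear. Below is a minimal standalone sketch of that equivalence; the driver program and file name are illustrative only, not part of ncnn, and it assumes an AArch64 toolchain with the dot-product extension (e.g. -O2 -march=armv8.2-a+dotprod).

/* dotlane_check.c -- illustrative sketch, not part of the patch.
 * Checks that the by-element form
 *     vdotq_laneq_s32(acc, w, x, LANE)
 * matches the older broadcast-then-dot sequence
 *     vdotq_s32(acc, w, vreinterpretq_s8_s32(
 *         vdupq_laneq_s32(vreinterpretq_s32_s8(x), LANE)))
 */
#include <arm_neon.h>
#include <stdint.h>
#include <stdio.h>

int main(void)
{
    int8_t x[16], w[16];
    for (int j = 0; j < 16; j++)
    {
        x[j] = (int8_t)(j - 8); /* arbitrary test values */
        w[j] = (int8_t)(3 * j - 20);
    }

    int8x16_t _x = vld1q_s8(x);
    int8x16_t _w = vld1q_s8(w);

    /* old form: broadcast 32-bit lane 1 of _x, then a plain vdotq_s32 */
    int8x16_t _x1 = vreinterpretq_s8_s32(vdupq_laneq_s32(vreinterpretq_s32_s8(_x), 1));
    int32x4_t _a = vdotq_s32(vdupq_n_s32(0), _w, _x1);

    /* new form: SDOT (by element) selects lane 1 of _x directly */
    int32x4_t _b = vdotq_laneq_s32(vdupq_n_s32(0), _w, _x, 1);

    int32_t a[4], b[4];
    vst1q_s32(a, _a);
    vst1q_s32(b, _b);
    for (int j = 0; j < 4; j++)
        printf("lane %d: %d %s %d\n", j, a[j], a[j] == b[j] ? "==" : "!=", b[j]);

    return 0;
}

The vdotq_lane_s32 form used in the patch's 8- and 4-element tails is the same idea with an int8x8_t right-hand operand, which likewise drops the widening vdupq_lane_s32 broadcast.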