diff --git a/src/layer/arm/gru_arm.cpp b/src/layer/arm/gru_arm.cpp index c50564685ac..a1a53903887 100644 --- a/src/layer/arm/gru_arm.cpp +++ b/src/layer/arm/gru_arm.cpp @@ -860,7 +860,8 @@ static int gru_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Ma float h_cont = hidden_state[i]; float32x4_t _h_cont = vdupq_n_f32(h_cont); - float32x4_t _weight_hc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_RUN))))), _descale_hc_N); + int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_RUN)[0])); + float32x4_t _weight_hc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_hc_N); _gru_N = vmlaq_f32(_gru_N, _weight_hc_N, _h_cont); weight_hc_int8_RUN += 4; @@ -907,7 +908,8 @@ static int gru_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Ma float xi = x[i]; float32x4_t _xi = vdupq_n_f32(xi); - float32x4_t _weight_xc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_RUN))))), _descale_xc_N); + int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_RUN)[0])); + float32x4_t _weight_xc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_xc_N); _gru_N = vmlaq_f32(_gru_N, _weight_xc_N, _xi); weight_xc_int8_RUN += 4; @@ -2152,7 +2154,8 @@ static int gru_bf16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co float h_cont = hidden_state[i]; float32x4_t _h_cont = vdupq_n_f32(h_cont); - float32x4_t _weight_hc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_RUN))))), _descale_hc_N); + int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_RUN)[0])); + float32x4_t _weight_hc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_hc_N); _gru_N = vmlaq_f32(_gru_N, _weight_hc_N, _h_cont); weight_hc_int8_RUN += 4; @@ -2199,7 +2202,8 @@ static int gru_bf16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co unsigned short xi = x[i]; float32x4_t _xi = bfloat2float(vdup_n_u16(xi)); - float32x4_t _weight_xc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_RUN))))), _descale_xc_N); + int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_RUN)[0])); + float32x4_t _weight_xc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_xc_N); _gru_N = vmlaq_f32(_gru_N, _weight_xc_N, _xi); weight_xc_int8_RUN += 4; diff --git a/src/layer/arm/gru_arm_asimdhp.cpp b/src/layer/arm/gru_arm_asimdhp.cpp index 1cad989f502..4278d2289ea 100644 --- a/src/layer/arm/gru_arm_asimdhp.cpp +++ b/src/layer/arm/gru_arm_asimdhp.cpp @@ -1002,7 +1002,8 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c float h_cont = *hidden_ptr++; float16x4_t _h_cont = vdup_n_f16((__fp16)h_cont); - float16x4_t _weight_hc_N = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_RUN)))), _descale_hc_N); + int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_RUN)[0])); + float16x4_t _weight_hc_N = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(_w))), _descale_hc_N); _gru_N = vfma_f16(_gru_N, _weight_hc_N, _h_cont); weight_hc_int8_RUN += 4; @@ -1076,7 +1077,8 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c __fp16 xi = *x++; float16x4_t _xi = vdup_n_f16(xi); - float16x4_t _weight_xc_N = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_RUN)))), _descale_xc_N); + int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_RUN)[0])); + float16x4_t _weight_xc_N = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(_w))), _descale_xc_N); _gru_N = vfma_f16(_gru_N, _weight_xc_N, _xi); weight_xc_int8_RUN += 4; @@ -1411,7 +1413,8 @@ static int gru_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co float h_cont = hidden_state[i]; float32x4_t _h_cont = vdupq_n_f32(h_cont); - float32x4_t _weight_hc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_RUN))))), _descale_hc_N); + int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_RUN)[0])); + float32x4_t _weight_hc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_hc_N); _gru_N = vmlaq_f32(_gru_N, _weight_hc_N, _h_cont); weight_hc_int8_RUN += 4; @@ -1451,7 +1454,8 @@ static int gru_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co __fp16 xi = x[i]; float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi)); - float32x4_t _weight_xc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_RUN))))), _descale_xc_N); + int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_RUN)[0])); + float32x4_t _weight_xc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_xc_N); _gru_N = vmlaq_f32(_gru_N, _weight_xc_N, _xi); weight_xc_int8_RUN += 4; diff --git a/src/layer/arm/lstm_arm.cpp b/src/layer/arm/lstm_arm.cpp index fed6207330a..cc1e38620fa 100644 --- a/src/layer/arm/lstm_arm.cpp +++ b/src/layer/arm/lstm_arm.cpp @@ -521,7 +521,8 @@ static int lstm_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const M #if __ARM_NEON float32x4_t _xi = vdupq_n_f32(xi); - float32x4_t _weight_xc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_IFOG))))); + int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_IFOG)[0])); + float32x4_t _weight_xc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))); _weight_xc_IFOG = vmulq_f32(_weight_xc_IFOG, _descale_xc); _IFOG = vmlaq_f32(_IFOG, _weight_xc_IFOG, _xi); #else @@ -573,7 +574,8 @@ static int lstm_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const M #if __ARM_NEON float32x4_t _h_cont = vdupq_n_f32(h_cont); - float32x4_t _weight_hc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_IFOG))))); + int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_IFOG)[0])); + float32x4_t _weight_hc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))); _weight_hc_IFOG = vmulq_f32(_weight_hc_IFOG, _descale_hc); _IFOG = vmlaq_f32(_IFOG, _weight_hc_IFOG, _h_cont); #else @@ -1436,7 +1438,8 @@ static int lstm_bf16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c unsigned short xi = x[i]; float32x4_t _xi = bfloat2float(vdup_n_u16(xi)); - float32x4_t _weight_xc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_IFOG))))); + int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_IFOG)[0])); + float32x4_t _weight_xc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))); _weight_xc_IFOG = vmulq_f32(_weight_xc_IFOG, _descale_xc); _IFOG = vmlaq_f32(_IFOG, _weight_xc_IFOG, _xi); #else @@ -1490,7 +1493,8 @@ static int lstm_bf16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c #if __ARM_NEON float32x4_t _h_cont = vdupq_n_f32(h_cont); - float32x4_t _weight_hc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_IFOG))))); + int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_IFOG)[0])); + float32x4_t _weight_hc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))); _weight_hc_IFOG = vmulq_f32(_weight_hc_IFOG, _descale_hc); _IFOG = vmlaq_f32(_IFOG, _weight_hc_IFOG, _h_cont); #else diff --git a/src/layer/arm/lstm_arm_asimdhp.cpp b/src/layer/arm/lstm_arm_asimdhp.cpp index d2ecc147e82..3c76d8dfaba 100644 --- a/src/layer/arm/lstm_arm_asimdhp.cpp +++ b/src/layer/arm/lstm_arm_asimdhp.cpp @@ -944,7 +944,8 @@ static int lstm_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, float16x4_t _xi = vdup_n_f16(xi); - float16x4_t _weight_xc_IFOG = vcvt_f16_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_IFOG)))); + int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_IFOG)[0])); + float16x4_t _weight_xc_IFOG = vcvt_f16_s16(vget_low_s16(vmovl_s8(_w))); _weight_xc_IFOG = vmul_f16(_weight_xc_IFOG, _descale_xc); _IFOG = vfma_f16(_IFOG, _weight_xc_IFOG, _xi); @@ -1012,7 +1013,8 @@ static int lstm_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, float16x4_t _h_cont = vdup_n_f16((__fp16)h_cont); - float16x4_t _weight_hc_IFOG = vcvt_f16_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_IFOG)))); + int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_IFOG)[0])); + float16x4_t _weight_hc_IFOG = vcvt_f16_s16(vget_low_s16(vmovl_s8(_w))); _weight_hc_IFOG = vmul_f16(_weight_hc_IFOG, _descale_hc); _IFOG = vfma_f16(_IFOG, _weight_hc_IFOG, _h_cont); @@ -1220,7 +1222,8 @@ static int lstm_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c __fp16 xi = x[i]; float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi)); - float32x4_t _weight_xc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_IFOG))))); + int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_IFOG)[0])); + float32x4_t _weight_xc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))); _weight_xc_IFOG = vmulq_f32(_weight_xc_IFOG, _descale_xc); _IFOG = vfmaq_f32(_IFOG, _weight_xc_IFOG, _xi); @@ -1256,7 +1259,8 @@ static int lstm_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c float h_cont = hidden_state[i]; float32x4_t _h_cont = vdupq_n_f32(h_cont); - float32x4_t _weight_hc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_IFOG))))); + int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_IFOG)[0])); + float32x4_t _weight_hc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))); _weight_hc_IFOG = vmulq_f32(_weight_hc_IFOG, _descale_hc); _IFOG = vfmaq_f32(_IFOG, _weight_hc_IFOG, _h_cont); diff --git a/src/layer/arm/rnn_arm.cpp b/src/layer/arm/rnn_arm.cpp index 90619a93604..c6e3ad18138 100644 --- a/src/layer/arm/rnn_arm.cpp +++ b/src/layer/arm/rnn_arm.cpp @@ -398,7 +398,8 @@ static int rnn_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Ma for (; i < size; i++) { float32x4_t _x = vdupq_n_f32(x[i]); - float32x4_t _weight_xc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_ptr))))), _descale_xc); + int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_ptr)[0])); + float32x4_t _weight_xc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_xc); _rnn_H = vmlaq_f32(_rnn_H, _weight_xc, _x); weight_xc_int8_ptr += 4; @@ -434,7 +435,8 @@ static int rnn_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Ma for (; i < num_output; i++) { float32x4_t _hidden_state = vdupq_n_f32(hidden_state[i]); - float32x4_t _weight_hc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_ptr))))), _descale_hc); + int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_ptr)[0])); + float32x4_t _weight_hc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_hc); _rnn_H = vmlaq_f32(_rnn_H, _weight_hc, _hidden_state); weight_hc_int8_ptr += 4; @@ -1115,7 +1117,8 @@ static int rnn_bf16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co for (; i < size; i++) { float32x4_t _x = bfloat2float(vdup_n_u16(x[i])); - float32x4_t _weight_xc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_ptr))))), _descale_xc); + int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_ptr)[0])); + float32x4_t _weight_xc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_xc); _rnn_H = vmlaq_f32(_rnn_H, _weight_xc, _x); weight_xc_int8_ptr += 4; @@ -1151,7 +1154,8 @@ static int rnn_bf16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co for (; i < num_output; i++) { float32x4_t _hidden_state = vdupq_n_f32(hidden_state[i]); - float32x4_t _weight_hc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_ptr))))), _descale_hc); + int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_ptr)[0])); + float32x4_t _weight_hc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_hc); _rnn_H = vmlaq_f32(_rnn_H, _weight_hc, _hidden_state); weight_hc_int8_ptr += 4; diff --git a/src/layer/arm/rnn_arm_asimdhp.cpp b/src/layer/arm/rnn_arm_asimdhp.cpp index cc93f6ddba1..a9e9b9acaa1 100644 --- a/src/layer/arm/rnn_arm_asimdhp.cpp +++ b/src/layer/arm/rnn_arm_asimdhp.cpp @@ -539,7 +539,8 @@ static int rnn_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c for (; i < size; i++) { float16x4_t _x = vdup_n_f16(x[i]); - float16x4_t _weight_xc = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_ptr)))), _descale_xc); + int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_ptr)[0])); + float16x4_t _weight_xc = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(_w))), _descale_xc); _rnn_H = vfma_f16(_rnn_H, _weight_xc, _x); weight_xc_int8_ptr += 4; @@ -568,7 +569,8 @@ static int rnn_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c for (; i < num_output; i++) { float16x4_t _hidden_state = vdup_n_f16((__fp16)hidden_state[i]); - float16x4_t _weight_hc = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_ptr)))), _descale_hc); + int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_ptr)[0])); + float16x4_t _weight_hc = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(_w))), _descale_hc); _rnn_H = vfma_f16(_rnn_H, _weight_hc, _hidden_state); weight_hc_int8_ptr += 4; @@ -706,7 +708,8 @@ static int rnn_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co for (; i < size; i++) { float32x4_t _x = vcvt_f32_f16(vdup_n_f16(x[i])); - float32x4_t _weight_xc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_ptr))))), _descale_xc); + int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_ptr)[0])); + float32x4_t _weight_xc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_xc); _rnn_H = vfmaq_f32(_rnn_H, _weight_xc, _x); weight_xc_int8_ptr += 4; @@ -735,7 +738,8 @@ static int rnn_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co for (; i < num_output; i++) { float32x4_t _hidden_state = vdupq_n_f32(hidden_state[i]); - float32x4_t _weight_hc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_ptr))))), _descale_hc); + int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_ptr)[0])); + float32x4_t _weight_hc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_hc); _rnn_H = vfmaq_f32(_rnn_H, _weight_hc, _hidden_state); weight_hc_int8_ptr += 4;