diff --git a/src/layer/arm/rnn_arm_vfpv4.cpp b/src/layer/arm/rnn_arm_vfpv4.cpp
new file mode 100644
index 00000000000..893f6e061b1
--- /dev/null
+++ b/src/layer/arm/rnn_arm_vfpv4.cpp
@@ -0,0 +1,30 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "cpu.h"
+#include "mat.h"
+#include "layer.h"
+#include "arm_activation.h"
+#include "arm_usability.h"
+
+namespace ncnn {
+
+#include "rnn_int8.h"
+
+void rnn_int8_gate_output_vfpv4(const Mat& gates, Mat& hidden_state, Mat& top_blob, int ti, int elemtype, const Option& opt)
+{
+    rnn_int8_gate_output(gates, hidden_state, top_blob, ti, elemtype, opt);
+}
+
+} // namespace ncnn
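Note on the new translation unit above: it follows ncnn's runtime-dispatch idiom. The generic rnn_int8.h source is textually included into a .cpp that the build system compiles with VFPv4 flags (e.g. -mfpu=neon-vfpv4), so fp16 conversion code inside rnn_int8_gate_output lowers to native instructions in that copy, and the baseline build selects it at runtime. A minimal sketch of the idiom follows; kernel, kernel_vfpv4 and kernel_dispatch are hypothetical illustration names, and only ncnn::cpu_support_arm_vfpv4() from cpu.h is real ncnn API:

// kernel.h (generic): the kernel body, textually included into every
// arch-specific translation unit, like rnn_int8_gate_output above.
static void kernel(float* dst, const float* src, int n)
{
    for (int i = 0; i < n; i++)
        dst[i] = src[i] * 2.f; // stand-in for the real gate-output work
}

// kernel_vfpv4.cpp: compiled with -mfpu=neon-vfpv4, so fp16 conversion
// intrinsics used inside kernel() lower to native vcvt instructions here.
void kernel_vfpv4(float* dst, const float* src, int n)
{
    kernel(dst, src, n);
}

// baseline translation unit: picks the best compiled copy at runtime.
#include "cpu.h"
void kernel_dispatch(float* dst, const float* src, int n)
{
#if NCNN_RUNTIME_CPU && NCNN_VFPV4 && __ARM_NEON && !(__ARM_FP & 2)
    if (ncnn::cpu_support_arm_vfpv4()) // runtime CPU feature check from cpu.h
    {
        kernel_vfpv4(dst, src, n);
        return;
    }
#endif
    kernel(dst, src, n); // portable fallback built with baseline flags
}

The !(__ARM_FP & 2) guard in the declarations below skips this indirection entirely when the baseline compiler flags already provide hardware fp16 conversion.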
"fcvtn v0.4h, %2.4s \n" + "st1 {v0.4h}, [%0] \n" + : "=r"(_rnn_H) // %0 + : "0"(outptr), + "w"(_rnn_H) + : "memory", "v0"); +#else // __aarch64__ + asm volatile( + "vcvt.f16.f32 d0, %q2 \n" + "vst1.u16 {d0}, [%0] \n" + : "=r"(outptr) // %0 + : "0"(outptr), + "w"(_rnn_H) + : "memory", "q0"); +#endif // __aarch64__ +#else // NCNN_GNU_INLINE_ASM + vst1_u16(outptr, (uint16x4_t)vcvt_f16_f32(_rnn_H)); +#endif // NCNN_GNU_INLINE_ASM +#else + outptr[q] = float32_to_float16(hidden_ptr[q]); + outptr[q + 1] = float32_to_float16(hidden_ptr[q + 1]); + outptr[q + 2] = float32_to_float16(hidden_ptr[q + 2]); + outptr[q + 3] = float32_to_float16(hidden_ptr[q + 3]); +#endif // (__ARM_FP & 2) + } + if (elemtype == 4) + { + // bf16 + vst1_u16((unsigned short*)output_data + q, float2bfloat(_rnn_H)); + } + } + remain_num_output_start += nn_num_output << 2; +#endif // __ARM_NEON + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + float H = gates[q]; + + hidden_ptr[q] = H; + + if (elemtype == 1) + { + output_data[q] = H; + } + if (elemtype == 2) + { + ((unsigned short*)output_data)[q] = float32_to_float16(H); + } + if (elemtype == 4) + { + ((unsigned short*)output_data)[q] = float32_to_bfloat16(H); + } + } +} + static void rnn_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int elemtype, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, Mat& hidden_state, const Option& opt) { // TODO dispatch for __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -490,59 +590,6 @@ static void rnn_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de gates[q] = H; } - float* output_data = top_blob.row(ti); - - float* hidden_ptr = hidden_state; - -#if __ARM_NEON - nn_num_output = num_output >> 2; - remain_num_output_start = nn_num_output << 2; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int qq = 0; qq < nn_num_output; qq++) - { - int q = qq * 4; - - float32x4_t _rnn_H = vld1q_f32((float*)gates + q); - - vst1q_f32(hidden_ptr + q, _rnn_H); - - if (elemtype == 1) - { - // fp32 - vst1q_f32(output_data + q, _rnn_H); - } - if (elemtype == 2) - { - // fp16 - vst1_u16((unsigned short*)output_data + q, (uint16x4_t)vcvt_f16_f32(_rnn_H)); - } - if (elemtype == 4) - { - // bf16 - vst1_u16((unsigned short*)output_data + q, float2bfloat(_rnn_H)); - } - } -#endif // __ARM_NEON - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = remain_num_output_start; q < num_output; q++) - { - float H = gates[q]; - - hidden_ptr[q] = H; - - if (elemtype == 1) - { - output_data[q] = H; - } - if (elemtype == 2) - { - ((unsigned short*)output_data)[q] = float32_to_float16(H); - } - if (elemtype == 4) - { - ((unsigned short*)output_data)[q] = float32_to_bfloat16(H); - } - } + rnn_int8_gate_output(gates, hidden_state, top_blob, ti, elemtype, opt); } }