From 12e2ebc56ac86b4f1e9e2047c3048a905b5b8ec8 Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Mon, 5 Aug 2024 08:50:50 -0700 Subject: [PATCH] ggml : make GeLU faster and more accurate on CPU This change makes GeLU go 8x faster on AVX2, 3x faster on Apple Silicon, and 2x faster on Threadripper. It is the world's most popular activation function, used by models such as Whisper and Gemma, where it can lead to a noticeable improvement in performance, because the GeLU op is the most time-consuming usually of any operation except for matrix multiplication In addition to improving performance this change also improves accuracy. On ARM64 and AMD64 systems, we no longer need to rely on a 16-bit lookup table. We're now using SIMD instead. The GeLU lookup table is still here except it's been converted from fp16 to bf16. This helps align inference more with training possibly, but it helps us avoid the two extra lookups into the fp16 table. Therefore this change should have a positive impact on performance for platforms like OpenPOWER and RISC-V too. --- ggml/src/ggml.c | 389 ++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 345 insertions(+), 44 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 910981e4a37ba6..0472ac58191df1 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -255,7 +255,7 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) { } #define GGML_DEBUG 0 -#define GGML_GELU_FP16 +#define GGML_GELU_BF16 #define GGML_GELU_QUICK_FP16 #define GGML_SOFT_MAX_UNROLL 4 @@ -390,8 +390,10 @@ typedef double ggml_float; // global data // -// precomputed gelu table for f16 (128 KB) -static ggml_fp16_t ggml_table_gelu_f16[1 << 16]; +#ifdef GGML_GELU_BF16 +// precomputed gelu table in brain16 (128 KB) +static ggml_bf16_t ggml_table_gelu_f16[1 << 16]; +#endif // precomputed quick gelu table for f16 (128 KB) static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16]; @@ -1839,6 +1841,19 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) { #define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR) #endif +// for GeLU and SiLU +#ifdef __FMA__ +#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z) +#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z) +#define MADD256(x, y, z) _mm256_fmadd_ps(x, y, z) +#define NMADD256(x, y, z) _mm256_fnmadd_ps(x, y, z) +#else +#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z) +#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y)) +#define MADD256(x, y, z) _mm256_add_ps(_mm256_mul_ps(x, y), z) +#define NMADD256(x, y, z) _mm256_sub_ps(z, _mm256_mul_ps(x, y)) +#endif + // // ggml context // @@ -2320,55 +2335,345 @@ inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } -static const float GELU_COEF_A = 0.044715f; +//////////////////////////////////////////////////////////////////////////////// +// There's always room for GeLU + +static const float GELU_COEF_A = .044715f; static const float GELU_QUICK_COEF = -1.702f; -static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; +static const float SQRT_2_OVER_PI = .79788456080286535587989211986876f; inline static float ggml_gelu_f32(float x) { - return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); + return .5f*x*(1.f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } -inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { - const uint16_t * i16 = (const uint16_t *) x; - for (int i = 0; i < n; ++i) { - y[i] = ggml_table_gelu_f16[i16[i]]; - } +#if defined(__ARM_NEON) && defined(__aarch64__) + +/* Approximation for single-precision vector tanh (2.58 ULP) + There is no support for signed zero whose sign is removed + There is no support for floating point exception handling + This code is based on the ARM Limited Optimized Routines. */ +inline static float32x4_t +ggml_vtanhf(float32x4_t x) +{ + const uint32x4_t ix = vreinterpretq_u32_f32(x); + const float32x4_t ax = vabsq_f32(x); + const uint32x4_t iax = vreinterpretq_u32_f32(ax); + const uint32x4_t sign = veorq_u32(ix, iax); + const uint32x4_t is_boring = vcgtq_u32(iax, vdupq_n_u32(0x41102cb3)); + const float32x4_t boring = + vreinterpretq_f32_u32(vorrq_u32(sign, vdupq_n_u32(0x3f800000))); + const uint32x4_t special = vcgtq_u32(iax, vdupq_n_u32(0x7f800000)); + const float32x4_t ex = vmulq_n_f32(x, 2); + const float32x4_t e = { 0x1.715476p+0f, 0x1.62e4p-1f, 0x1.7f7d1cp-20f }; + const float32x4_t j = + vsubq_f32(vfmaq_laneq_f32(vdupq_n_f32(0x1.8p23f), ex, e, 0), + vdupq_n_f32(0x1.8p23f)); + const int32x4_t i = vcvtq_s32_f32(j); + const float32x4_t f = vfmsq_laneq_f32(ex, j, e, 1); + const float32x4_t f1 = vfmsq_laneq_f32(f, j, e, 2); + const float32x4_t f2 = vmulq_f32(f1, f1); + const float32x4_t f4 = vmulq_f32(f2, f2); + const float32x4_t p01 = + vfmaq_f32(vdupq_n_f32(0x1.fffffep-2), vdupq_n_f32(0x1.5554aep-3), f1); + const float32x4_t p23 = + vfmaq_f32(vdupq_n_f32(0x1.555736p-5), vdupq_n_f32(0x1.12287cp-7), f1); + const float32x4_t p03 = vfmaq_f32(p01, p23, f2); + const float32x4_t p = vfmaq_f32(p03, vdupq_n_f32(0x1.6b55a2p-10), f4); + const float32x4_t p2 = vfmaq_f32(f1, f2, p); + const int32x4_t u = vaddq_s32(vshlq_n_s32(i, 23), vdupq_n_s32(0x3f800000)); + const float32x4_t t = vreinterpretq_f32_s32(u); + const float32x4_t q = vfmaq_f32(vsubq_f32(t, vdupq_n_f32(1)), p2, t); + const float32x4_t y = vdivq_f32(q, vaddq_f32(q, vdupq_n_f32(2))); + const float32x4_t result = vbslq_f32(is_boring, boring, y); + if (!vpaddd_u64(vreinterpretq_u64_u32(special))) + return result; + return (float32x4_t){ special[0] ? tanhf(x[0]) : result[0], + special[1] ? tanhf(x[1]) : result[1], + special[2] ? tanhf(x[2]) : result[2], + special[3] ? tanhf(x[3]) : result[3] }; +} + +inline static float32x4_t +ggml_vgeluf(float32x4_t x) +{ + const float32x4_t one = vdupq_n_f32(1); + const float32x4_t half = vdupq_n_f32(.5); + const float32x4_t coef_a = vdupq_n_f32(GELU_COEF_A); + const float32x4_t sqrt_2_over_pi = vdupq_n_f32(SQRT_2_OVER_PI); + const float32x4_t x_squared = vmulq_f32(x, x); + const float32x4_t ax2 = vmulq_f32(coef_a, x_squared); + const float32x4_t one_plus_ax2 = vaddq_f32(one, ax2); + const float32x4_t inner = + vmulq_f32(vmulq_f32(sqrt_2_over_pi, x), one_plus_ax2); + const float32x4_t tanh_inner = ggml_vtanhf(inner); + const float32x4_t one_plus_tanh = vaddq_f32(one, tanh_inner); + return vmulq_f32(vmulq_f32(half, x), one_plus_tanh); } -#ifdef GGML_GELU_FP16 -inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { - uint16_t t; - for (int i = 0; i < n; ++i) { - if (x[i] <= -10.0f) { - y[i] = 0.0f; - } else if (x[i] >= 10.0f) { - y[i] = x[i]; - } else { - ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); - memcpy(&t, &fp16, sizeof(uint16_t)); - y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]); - } - } +#elif defined(__AVX512F__) && defined(__AVX512DQ__) + +/* Approximation for single-precision vector tanh(x) using a + branchless algorithm that offers a maximum error of 4 ULP + + 108638843x off by one errors + 18273656x 2 to 3 ulp errors + 124x 4 ulp erors (e.g. 0.203652 [3e508a10]) + 1x sign flip + + There is no support for signed zero whose sign is removed + There is no support for floating point exception handling + This code is based on the ARM Limited Optimized Routines. */ +inline static __m512 +ggml_vtanhf(__m512 x) +{ + const __m512 sign_mask = _mm512_castsi512_ps(_mm512_set1_epi32(0x80000000)); + const __m512 one = _mm512_set1_ps(1); + const __m512 two = _mm512_set1_ps(2); + const __m512 ax = _mm512_abs_ps(x); + const __m512 sign = _mm512_and_ps(x, sign_mask); + const __mmask16 is_boring = + _mm512_cmp_ps_mask(ax, _mm512_set1_ps(0x1.205966p+3), _CMP_GT_OQ); + const __m512 boring = _mm512_or_ps(sign, one); + const __m512 ex = _mm512_mul_ps(x, two); + const __m512 j = _mm512_fmadd_ps( + ex, _mm512_set1_ps(0x1.715476p+0f), _mm512_set1_ps(0x1.8p23f)); + const __m512 jj = _mm512_sub_ps(j, _mm512_set1_ps(0x1.8p23f)); + const __m512i i = _mm512_cvttps_epi32(jj); + const __m512 f = _mm512_fnmadd_ps(_mm512_set1_ps(0x1.62e4p-1f), jj, ex); + const __m512 f1 = _mm512_fnmadd_ps(_mm512_set1_ps(0x1.7f7d1cp-20f), jj, f); + const __m512 f2 = _mm512_mul_ps(f1, f1); + const __m512 f4 = _mm512_mul_ps(f2, f2); + const __m512 p01 = _mm512_fmadd_ps( + f1, _mm512_set1_ps(0x1.5554aep-3), _mm512_set1_ps(0x1.fffffep-2)); + const __m512 p23 = _mm512_fmadd_ps( + f1, _mm512_set1_ps(0x1.12287cp-7), _mm512_set1_ps(0x1.555736p-5)); + const __m512 p03 = _mm512_fmadd_ps(f2, p23, p01); + const __m512 p = _mm512_fmadd_ps(f4, _mm512_set1_ps(0x1.6b55a2p-10), p03); + const __m512 p2 = _mm512_fmadd_ps(f2, p, f1); + const __m512i u = + _mm512_add_epi32(_mm512_slli_epi32(i, 23), _mm512_set1_epi32(0x3f800000)); + const __m512 t = _mm512_castsi512_ps(u); + const __m512 q = _mm512_fmadd_ps(p2, t, _mm512_sub_ps(t, one)); + const __m512 y = _mm512_div_ps(q, _mm512_add_ps(q, two)); + return _mm512_mask_blend_ps(is_boring, y, boring); +} + +inline static __m512 +ggml_vgeluf(__m512 x) +{ + const __m512 one = _mm512_set1_ps(1); + const __m512 half = _mm512_set1_ps(.5); + const __m512 coef_a = _mm512_set1_ps(GELU_COEF_A); + const __m512 sqrt_2_over_pi = _mm512_set1_ps(SQRT_2_OVER_PI); + const __m512 x_squared = _mm512_mul_ps(x, x); + const __m512 ax2 = _mm512_mul_ps(coef_a, x_squared); + const __m512 one_plus_ax2 = _mm512_add_ps(one, ax2); + const __m512 inner = + _mm512_mul_ps(_mm512_mul_ps(sqrt_2_over_pi, x), one_plus_ax2); + const __m512 tanh_inner = ggml_vtanhf(inner); + const __m512 one_plus_tanh = _mm512_add_ps(one, tanh_inner); + return _mm512_mul_ps(_mm512_mul_ps(half, x), one_plus_tanh); +} + +#elif defined(__AVX2__) + +/* Approximation for single-precision vector tanh(x) using a + branchless algorithm that offers a maximum error of 4 ULP + + With fused multiply add: + + 108638843x off by one errors + 18273656x 2 to 3 ulp errors + 124x 4 ulp erors (e.g. 0.203652 [3e508a10]) + 1x sign flip + + Without fused multiply add: + + 108479590x off by one errors + 18209645x 2 to 3 ulp errors + 70x 4 ulp errors (e.g. 0.205979 [3e52ec19]) + 1x sign flip + + There is no support for signed zero whose sign is removed + There is no support for floating point exception handling + This code is based on the ARM Limited Optimized Routines. */ +inline static __m256 +ggml_vtanhf(__m256 x) +{ + const __m256 abs_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF)); + const __m256 one = _mm256_set1_ps(1); + const __m256 two = _mm256_set1_ps(2); + const __m256 ax = _mm256_and_ps(x, abs_mask); + const __m256 sign = _mm256_and_ps(x, _mm256_set1_ps(-0.f)); + const __m256 is_boring = + _mm256_cmp_ps(ax, _mm256_set1_ps(0x1.205966p+3), _CMP_GT_OQ); + const __m256 boring = _mm256_or_ps(sign, one); + const __m256 ex = _mm256_mul_ps(x, two); + const __m256 j = + MADD256(ex, _mm256_set1_ps(0x1.715476p+0f), _mm256_set1_ps(0x1.8p23f)); + const __m256 jj = _mm256_sub_ps(j, _mm256_set1_ps(0x1.8p23f)); + const __m256i i = _mm256_cvttps_epi32(jj); + const __m256 f = NMADD256(_mm256_set1_ps(0x1.62e4p-1f), jj, ex); + const __m256 f1 = NMADD256(_mm256_set1_ps(0x1.7f7d1cp-20f), jj, f); + const __m256 f2 = _mm256_mul_ps(f1, f1); + const __m256 f4 = _mm256_mul_ps(f2, f2); + const __m256 p01 = + MADD256(f1, _mm256_set1_ps(0x1.5554aep-3), _mm256_set1_ps(0x1.fffffep-2)); + const __m256 p23 = + MADD256(f1, _mm256_set1_ps(0x1.12287cp-7), _mm256_set1_ps(0x1.555736p-5)); + const __m256 p03 = MADD256(f2, p23, p01); + const __m256 p = MADD256(f4, _mm256_set1_ps(0x1.6b55a2p-10), p03); + const __m256 p2 = MADD256(f2, p, f1); + const __m256i u = + _mm256_add_epi32(_mm256_slli_epi32(i, 23), _mm256_set1_epi32(0x3f800000)); + const __m256 t = _mm256_castsi256_ps(u); + const __m256 q = MADD256(p2, t, _mm256_sub_ps(t, one)); + const __m256 y = _mm256_div_ps(q, _mm256_add_ps(q, two)); + return _mm256_or_ps(_mm256_and_ps(is_boring, boring), + _mm256_andnot_ps(is_boring, y)); +} + +inline static __m256 +ggml_vgeluf(__m256 x) +{ + const __m256 one = _mm256_set1_ps(1); + const __m256 half = _mm256_set1_ps(.5); + const __m256 coef_a = _mm256_set1_ps(GELU_COEF_A); + const __m256 sqrt_2_over_pi = _mm256_set1_ps(SQRT_2_OVER_PI); + const __m256 x_squared = _mm256_mul_ps(x, x); + const __m256 ax2 = _mm256_mul_ps(coef_a, x_squared); + const __m256 one_plus_ax2 = _mm256_add_ps(one, ax2); + const __m256 inner = + _mm256_mul_ps(_mm256_mul_ps(sqrt_2_over_pi, x), one_plus_ax2); + const __m256 tanh_inner = ggml_vtanhf(inner); + const __m256 one_plus_tanh = _mm256_add_ps(one, tanh_inner); + return _mm256_mul_ps(_mm256_mul_ps(half, x), one_plus_tanh); +} + +#elif defined(__SSE2__) + +/* Approximation for single-precision vector tanh(x) using a + branchless algorithm that offers a maximum error of 4 ULP + + Without fused multiply add: + + 108479590x off by one errors + 18209645x 2 to 3 ulp errors + 70x 4 ulp errors (e.g. 0.205979 [3e52ec19]) + 1x sign flip + + With fused multiply add: + + 108638843x off by one errors + 18273656x 2 to 3 ulp errors + 124x 4 ulp erors (e.g. 0.203652 [3e508a10]) + 1x sign flip + + There is no support for signed zero whose sign is removed + There is no support for floating point exception handling + This code is based on the ARM Limited Optimized Routines. */ +inline static __m128 +ggml_vtanhf(__m128 x) +{ + const __m128 abs_mask = _mm_castsi128_ps(_mm_set1_epi32(0x7FFFFFFF)); + const __m128 one = _mm_set1_ps(1); + const __m128 two = _mm_set1_ps(2); + const __m128 ax = _mm_and_ps(x, abs_mask); + const __m128 sign = _mm_and_ps(x, _mm_set1_ps(-0.f)); + const __m128 is_boring = _mm_cmpgt_ps(ax, _mm_set1_ps(0x1.205966p+3)); + const __m128 boring = _mm_or_ps(sign, one); + const __m128 ex = _mm_mul_ps(x, two); + const __m128 j = + MADD128(ex, _mm_set1_ps(0x1.715476p+0f), _mm_set1_ps(0x1.8p23f)); + const __m128 jj = _mm_sub_ps(j, _mm_set1_ps(0x1.8p23f)); + const __m128i i = _mm_cvttps_epi32(jj); + const __m128 f = NMADD128(_mm_set1_ps(0x1.62e4p-1f), jj, ex); + const __m128 f1 = NMADD128(_mm_set1_ps(0x1.7f7d1cp-20f), jj, f); + const __m128 f2 = _mm_mul_ps(f1, f1); + const __m128 f4 = _mm_mul_ps(f2, f2); + const __m128 p01 = + MADD128(f1, _mm_set1_ps(0x1.5554aep-3), _mm_set1_ps(0x1.fffffep-2)); + const __m128 p23 = + MADD128(f1, _mm_set1_ps(0x1.12287cp-7), _mm_set1_ps(0x1.555736p-5)); + const __m128 p03 = MADD128(f2, p23, p01); + const __m128 p = MADD128(f4, _mm_set1_ps(0x1.6b55a2p-10), p03); + const __m128 p2 = MADD128(f2, p, f1); + const __m128i u = + _mm_add_epi32(_mm_slli_epi32(i, 23), _mm_set1_epi32(0x3f800000)); + const __m128 t = _mm_castsi128_ps(u); + const __m128 q = MADD128(p2, t, _mm_sub_ps(t, one)); + const __m128 y = _mm_div_ps(q, _mm_add_ps(q, two)); + return _mm_or_ps(_mm_and_ps(is_boring, boring), + _mm_andnot_ps(is_boring, y)); +} + +inline static __m128 +ggml_vgeluf(__m128 x) +{ + const __m128 one = _mm_set1_ps(1); + const __m128 half = _mm_set1_ps(.5); + const __m128 coef_a = _mm_set1_ps(GELU_COEF_A); + const __m128 sqrt_2_over_pi = _mm_set1_ps(SQRT_2_OVER_PI); + const __m128 x_squared = _mm_mul_ps(x, x); + const __m128 ax2 = _mm_mul_ps(coef_a, x_squared); + const __m128 one_plus_ax2 = _mm_add_ps(one, ax2); + const __m128 inner = + _mm_mul_ps(_mm_mul_ps(sqrt_2_over_pi, x), one_plus_ax2); + const __m128 tanh_inner = ggml_vtanhf(inner); + const __m128 one_plus_tanh = _mm_add_ps(one, tanh_inner); + return _mm_mul_ps(_mm_mul_ps(half, x), one_plus_tanh); } + +#endif + +static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { + int i = 0; + // vectorized arm gelu goes 10x faster with ~30 bits of accuracy +#if defined(__ARM_NEON) && defined(__aarch64__) + for (; i + 3 < n; i += 4) { + vst1q_f32(y + i, ggml_vgeluf(vld1q_f32(x + i))); + } +#elif defined(__AVX512F__) && defined(__AVX512DQ__) + for (; i + 15 < n; i += 16) { + _mm512_storeu_ps(y + i, ggml_vgeluf(_mm512_loadu_ps(x + i))); + } +#elif defined(__AVX2__) + for (; i + 7 < n; i += 8) { + _mm256_storeu_ps(y + i, ggml_vgeluf(_mm256_loadu_ps(x + i))); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + _mm_storeu_ps(y + i, ggml_vgeluf(_mm_loadu_ps(x + i))); + } +#endif + for (; i < n; ++i) { +#ifdef GGML_GELU_BF16 + // gelu brain lut goes 5x faster with ~16 bits of accuracy + // this is the only game in town, if not on arm64 or amd64 + union { + float f; + uint32_t i; + } pun32; + union { + uint16_t i; + ggml_bf16_t f; + } pun16; + pun32.f = x[i]; + int k = (pun32.i + (0x7fff + ((pun32.i >> 16) & 1))) >> 16; + pun16.f = ggml_table_gelu_f16[k]; + pun32.i = (uint32_t)pun16.i << 16; + y[i] = pun32.f; #else -inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { - for (int i = 0; i < n; ++i) { + // computes canonical gelu approximation with ~32 bit accuracy y[i] = ggml_gelu_f32(x[i]); +#endif } } -#endif inline static float ggml_gelu_quick_f32(float x) { return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x))); } -//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { -// const uint16_t * i16 = (const uint16_t *) x; -// for (int i = 0; i < n; ++i) { -// y[i] = ggml_table_gelu_quick_f16[i16[i]]; -// } -//} - #ifdef GGML_GELU_QUICK_FP16 inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { uint16_t t; @@ -2535,14 +2840,6 @@ inline static __m256 ggml_v_silu(__m256 x) { #elif defined(__SSE2__) // __AVX2__ / __ARM_NEON -#if defined(__FMA__) -#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z) -#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z) -#else -#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z) -#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y)) -#endif - // adapted from arm limited optimized routine // the maximum error is 1.45358 plus 0.5 ulps // numbers above 88.38 will flush to infinity @@ -3481,10 +3778,14 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { union { uint16_t u16; ggml_fp16_t fp16; + ggml_bf16_t bf16; } u = {i}; float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16); - ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f)); ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f)); +#ifdef GGML_GELU_BF16 + ggml_table_gelu_f16[i] = ggml_compute_fp32_to_bf16( + ggml_gelu_f32(ggml_compute_bf16_to_fp32(u.bf16))); +#endif } const uint64_t t_end = ggml_time_us(); UNUSED(t_end);