From f30c57107439ebf5fa4cca7c880691f4088b8149 Mon Sep 17 00:00:00 2001 From: Cebtenzzre Date: Wed, 19 Jul 2023 00:22:05 -0400 Subject: [PATCH] llama: reduce code duplication in NTKv2 RoPE --- ggml-cuda.cu | 34 +++++------------------- ggml.c | 75 ++++++++++++++++++++++++++-------------------------- ggml.h | 3 +++ 3 files changed, 48 insertions(+), 64 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 1aef214c543de9..3cb8a61acc354c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1875,14 +1875,14 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, cpy_1(cx + x_offset, cdst + dst_offset); } -static __device__ float ntkv2_ramp(const float low, const float high, const int i0) { +static __device__ float ggml_rope_ntkv2_ramp(const float low, const float high, const int i0) { const float y = (i0 / 2 - low) / min(0.001f, high - low); return 1.0f - min(1.0f, max(0.0f, y)); } // NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. 
-static __device__ float compute_ntkv2( +static __device__ float ggml_rope_ntkv2( float theta_base, float theta_linear, float theta_ntk, @@ -1893,10 +1893,10 @@ static __device__ float compute_ntkv2( float ramp_mix; float theta; - ramp_mix = ntkv2_ramp(corr_factors[0], corr_factors[1], i0) * ntk_factor; + ramp_mix = ggml_rope_ntkv2_ramp(corr_factors[0], corr_factors[1], i0) * ntk_factor; theta = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix; - ramp_mix = ntkv2_ramp(corr_factors[2], corr_factors[3], i0) * extrapolation_factor; + ramp_mix = ggml_rope_ntkv2_ramp(corr_factors[2], corr_factors[3], i0) * extrapolation_factor; theta = theta * (1 - ramp_mix) + theta_base * ramp_mix; return theta; } @@ -1918,7 +1918,7 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float theta_base = p*powf(theta_scale, col/2); const float theta_linear = freq_scale * theta_base; const float theta_ntk = p*powf(theta_ntk_scale, col/2); - const float theta = compute_ntkv2(theta_base, theta_linear, theta_ntk, corr_factors, col, ntk_factor, + const float theta = ggml_rope_ntkv2(theta_base, theta_linear, theta_ntk, corr_factors, col, ntk_factor, extrapolation_factor); const float sin_theta = sinf(theta); const float cos_theta = cosf(theta); @@ -2974,13 +2974,6 @@ inline void ggml_cuda_op_mul_mat_cublas( (void) i1; } -// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get -// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` -static float ntkv2_correction_factor(const int n_dims, const float n_rot, const float base) { - static const float max_pos_emb = 2048; - return n_dims * logf(max_pos_emb / (n_rot * 2 * (float)M_PI)) / (2 * logf(base)); -} - inline void ggml_cuda_op_rope( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, @@ 
-3018,21 +3011,8 @@ inline void ggml_cuda_op_rope( rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main); } else { const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims); - - // Interpolation constants found experimentally for LLaMA (might not be totally optimal though) - // Do not change unless there is a good reason for doing so! - static const float BETA_0 = 1.75f; - static const float BETA_1 = 1.25f; - static const float GAMMA_0 = 16.0f; - static const float GAMMA_1 = 2.0f; - - // start and end correction factors - const float corr_factors[4] = { - max(0.0f, floorf(ntkv2_correction_factor(n_dims, BETA_0, freq_base))), - min(n_dims - 1.0f, ceilf(ntkv2_correction_factor(n_dims, BETA_1, freq_base))), - max(0.0f, floorf(ntkv2_correction_factor(n_dims, GAMMA_0, freq_base))), - min(n_dims - 1.0f, ceilf(ntkv2_correction_factor(n_dims, GAMMA_1, freq_base))), - }; + float corr_factors[4]; + ggml_rope_ntkv2_corr_factors(n_dims, freq_base, corr_factors); rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, freq_scale, ntk_factor, extrapolation_factor, theta_scale, theta_ntk_scale, p, corr_factors, cudaStream_main); diff --git a/ggml.c b/ggml.c index 549986f52aacbd..3f64b7ccf023bb 100644 --- a/ggml.c +++ b/ggml.c @@ -12093,29 +12093,42 @@ static void ggml_compute_forward_clamp( // ggml_compute_forward_rope -// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get -// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` -#define NTKV2_MAX_POS_EMB 2048 -#define NTKV2_CORRECTION_FACTOR(n_rot) (__builtin_logf(NTKV2_MAX_POS_EMB / ((n_rot) * 2 * (float)M_PI)) / 2) - // use -ffast-math so MIN and MAX are optimized to vminss and vmaxss __attribute__((optimize("-ffast-math"), always_inline)) -static inline float ntkv2_ramp(const float low, const float high, const int i0) { +static inline float ggml_rope_ntkv2_ramp(const 
float low, const float high, const int i0) { const float y = (i0 / 2 - low) / MIN(0.001f, high - low); return 1 - MIN(1, MAX(0, y)); } // NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. -static float compute_ntkv2( +static float ggml_rope_ntkv2( float theta_base, + float theta_linear, float theta_ntk, - float dims_over_base, - float freq_scale, + const float corr_factors[4], int64_t i0, float ntk_factor, - float extrapolation_factor, - int n_dims) { + float extrapolation_factor) { + float ramp_mix; + float theta; + + ramp_mix = ggml_rope_ntkv2_ramp(corr_factors[0], corr_factors[1], i0) * ntk_factor; + theta = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix; + + ramp_mix = ggml_rope_ntkv2_ramp(corr_factors[2], corr_factors[3], i0) * extrapolation_factor; + theta = theta * (1 - ramp_mix) + theta_base * ramp_mix; + return theta; +} + +// Solving `n_rot = max_pos_emb / (2pi * base^((2 * x) / n_dims))` for the dimension index x, we get +// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` +static float ggml_rope_ntkv2_corr_factor(const int n_dims, const float n_rot, const float base) { + static const float max_pos_emb = 2048; + return n_dims * logf(max_pos_emb / (n_rot * 2 * (float)M_PI)) / (2 * logf(base)); +} + +void ggml_rope_ntkv2_corr_factors(int n_dims, const float freq_base, float factors[4]) { // Interpolation constants found experimentally for LLaMA (might not be totally optimal though) // Do not change unless there is a good reason for doing so! 
static const float BETA_0 = 1.75f; @@ -12123,27 +12136,11 @@ static float compute_ntkv2( static const float GAMMA_0 = 16.0f; static const float GAMMA_1 = 2.0f; - static const float low_1p = NTKV2_CORRECTION_FACTOR(BETA_0); - static const float high_1p = NTKV2_CORRECTION_FACTOR(BETA_1); - static const float low_2p = NTKV2_CORRECTION_FACTOR(GAMMA_0); - static const float high_2p = NTKV2_CORRECTION_FACTOR(GAMMA_1); - // start and end correction factors - const float low_1 = maxf(0.0f, floorf(low_1p * dims_over_base)); - const float high_1 = minf(n_dims - 1.0f, ceilf(high_1p * dims_over_base)); - const float low_2 = maxf(0.0f, floorf(low_2p * dims_over_base)); - const float high_2 = minf(n_dims - 1.0f, ceilf(high_2p * dims_over_base)); - - const float theta_linear = freq_scale * theta_base; - float ramp_mix; - float theta; - - ramp_mix = ntkv2_ramp(low_1, high_1, i0) * ntk_factor; - theta = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix; - - ramp_mix = ntkv2_ramp(low_2, high_2, i0) * extrapolation_factor; - theta = theta * (1 - ramp_mix) + theta_base * ramp_mix; - return theta; + factors[0] = maxf(0.0f, floorf(ggml_rope_ntkv2_corr_factor(n_dims, BETA_0, freq_base))); + factors[1] = minf(n_dims - 1.0f, ceilf(ggml_rope_ntkv2_corr_factor(n_dims, BETA_1, freq_base))); + factors[2] = maxf(0.0f, floorf(ggml_rope_ntkv2_corr_factor(n_dims, GAMMA_0, freq_base))); + factors[3] = minf(n_dims - 1.0f, ceilf(ggml_rope_ntkv2_corr_factor(n_dims, GAMMA_1, freq_base))); } static void ggml_compute_forward_rope_f32( @@ -12201,7 +12198,8 @@ static void ggml_compute_forward_rope_f32( const float theta_scale = powf(freq_base, -2.0f/n_dims); const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims); - const float dims_over_base = n_dims / logf(freq_base); + float corr_factors[4]; + ggml_rope_ntkv2_corr_factors(n_dims, freq_base, corr_factors); const bool is_neox = mode & 2; const bool is_glm = mode & 4; @@ -12243,8 +12241,9 @@ static void 
ggml_compute_forward_rope_f32( } } else if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { - const float theta = compute_ntkv2(theta_base, theta_ntk, dims_over_base, - freq_scale, i0, ntk_factor, extrapolation_factor, n_dims); + const float theta_linear = freq_scale * theta_base; + const float theta = ggml_rope_ntkv2(theta_base, theta_linear, theta_ntk, corr_factors, + i0, ntk_factor, extrapolation_factor); const float cos_theta = cosf(theta); const float sin_theta = sinf(theta); @@ -12343,7 +12342,8 @@ static void ggml_compute_forward_rope_f16( const float theta_scale = powf(freq_base, -2.0f/n_dims); const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims); - const float dims_over_base = n_dims / logf(freq_base); + float corr_factors[4]; + ggml_rope_ntkv2_corr_factors(n_dims, freq_base, corr_factors); const bool is_neox = mode & 2; const bool is_glm = mode & 4; @@ -12385,8 +12385,9 @@ static void ggml_compute_forward_rope_f16( } } if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { - const float theta = compute_ntkv2(theta_base, theta_ntk, dims_over_base, - freq_scale, i0, ntk_factor, extrapolation_factor, n_dims); + const float theta_linear = freq_scale * theta_base; + const float theta = ggml_rope_ntkv2(theta_base, theta_linear, theta_ntk, corr_factors, + i0, ntk_factor, extrapolation_factor); const float cos_theta = cosf(theta); const float sin_theta = sinf(theta); diff --git a/ggml.h b/ggml.h index 9dca37ec55e731..8c98b0ffac244d 100644 --- a/ggml.h +++ b/ggml.h @@ -1134,6 +1134,9 @@ extern "C" { float extrapolation_factor, int n_ctx); + // compute correction factors for NTKv2 RoPE scaling + void ggml_rope_ntkv2_corr_factors(int n_dims, const float freq_base, float factors[4]); + // rotary position embedding backward, i.e compute dx from dy // a - dy GGML_API struct ggml_tensor * ggml_rope_back(