diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 9d42efb0d0b03..91a6edca60011 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -3558,9 +3558,52 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
 
+// Linear ramp over the rotary pair index i0/2: 1 at or below `low`, falling to 0 at or above `high`.
+// The span is clamped to at least 0.001 so that high == low cannot divide by zero.
+static __device__ float rope_ntkv2_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// The four ramp bounds filled in by ggml_rope_ntkv2_corr_factors, wrapped so they can be passed to the kernel by value.
+struct rope_corr_factors {
+    float v[4];
+};
+
+// NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static __device__ float rope_ntkv2(
+        const float theta_base,
+        const float theta_linear,
+        const float theta_ntk,
+        const rope_corr_factors corr_factors,
+        const int64_t i0,
+        const float ntk_factor,
+        const float ext_factor) {
+    float ramp_mix;
+    float theta;
+
+    ramp_mix = rope_ntkv2_ramp(corr_factors.v[0], corr_factors.v[1], i0) * ntk_factor;
+    theta = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix;
+
+    ramp_mix = rope_ntkv2_ramp(corr_factors.v[2], corr_factors.v[3], i0) * ext_factor;
+    theta = theta * (1 - ramp_mix) + theta_base * ramp_mix;
+    return theta;
+}
+
 // rope == RoPE == rotary positional embedding
-static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0,
-                                const float p_delta, const int p_delta_rows, const float theta_scale) {
+static __global__ void rope_f32(
+        const float * x,
+        float * dst,
+        const int ncols,
+        const float freq_scale,
+        const float ntk_factor,
+        const float ext_factor,
+        const float theta_scale,
+        const float theta_ntk_scale,
+        const float p0,
+        const int p_delta_rows,
+        const rope_corr_factors corr_factors) {
     const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
 
     if (col >= ncols) {
@@ -3570,7 +3613,11 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     const int row = blockDim.y*blockIdx.y + threadIdx.y;
     const int i = row*ncols + col;
 
-    const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
+    const float p = p0 + row / p_delta_rows;
+    const float theta_base = p*powf(theta_scale, col/2);
+    const float theta_linear = freq_scale * theta_base;
+    const float theta_ntk = p*powf(theta_ntk_scale, col/2);
+    const float theta = rope_ntkv2(theta_base, theta_linear, theta_ntk, corr_factors, col, ntk_factor, ext_factor);
     const float sin_theta = sinf(theta);
     const float cos_theta = cosf(theta);
 
@@ -4234,13 +4281,26 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
 }
 
-static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
-                          const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
+static void rope_f32_cuda(
+        const float * x,
+        float * dst,
+        const int ncols,
+        const int nrows,
+        const float freq_scale,
+        const float ntk_factor,
+        const float ext_factor,
+        const float theta_scale,
+        const float theta_ntk_scale,
+        const float p0,
+        const int p_delta_rows,
+        const rope_corr_factors corr_factors,
+        cudaStream_t stream) {
     GGML_ASSERT(nrows % 2 == 0);
     const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(num_blocks_x, nrows, 1);
-    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
+    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, freq_scale, ntk_factor, ext_factor, theta_scale,
+        theta_ntk_scale, p0, p_delta_rows, corr_factors);
 }
 
 static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
@@ -4941,11 +5001,13 @@ inline void ggml_cuda_op_rope(
     const int n_dims     = ((int32_t *) dst->op_params)[1];
     const int mode       = ((int32_t *) dst->op_params)[2];
     const int n_ctx      = ((int32_t *) dst->op_params)[3];
-    // RoPE alteration for extended context
 
-    float freq_base, freq_scale;
+    // RoPE alteration for extended context
+    float freq_base, freq_scale, ntk_factor, ext_factor;
     memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+    memcpy(&ntk_factor, (int32_t *) dst->op_params + 6, sizeof(float));
+    memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
@@ -4958,8 +5020,13 @@ inline void ggml_cuda_op_rope(
         const float block_p = max(p - (n_ctx - 2.f), 0.f);
         rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
     } else {
-        const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
-        rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
+        const float p0 = (mode & 1) == 0 ? n_past : 0;
+        const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims);
+        rope_corr_factors corr_factors;
+        ggml_rope_ntkv2_corr_factors(n_dims, freq_base, corr_factors.v);
+
+        rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, freq_scale, ntk_factor, ext_factor, theta_scale,
+            theta_ntk_scale, p0, ne01, corr_factors, cudaStream_main);
     }
 
     (void) src1;
diff --git a/ggml.c b/ggml.c
index 8c5f7ac2641ef..8a57391811dcc 100644
--- a/ggml.c
+++ b/ggml.c
@@ -12012,11 +12012,7 @@ static void ggml_compute_forward_clamp(
 
 // ggml_compute_forward_rope
 
-// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
-// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
-#define NTKV2_MAX_POS_EMB 2048
-#define NTKV2_CORRECTION_FACTOR(n_rot) (__builtin_logf(NTKV2_MAX_POS_EMB / ((n_rot) * 2 * (float)M_PI)) / 2)
-
 static inline float rope_ntkv2_ramp(const float low, const float high, const int i0) {
-    const float y = (i0 / 2 - low) / MIN(0.001f, high - low);
+    // clamp the span from below so that high == low cannot divide by zero
+    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
     return 1 - MIN(1, MAX(0, y));
@@ -12026,36 +12022,43 @@ static inline float rope_ntkv2_ramp(const float low, const float high, const int
 // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 static float rope_ntkv2(
     const float theta_base,
+    const float theta_linear,
     const float theta_ntk,
-    const float dims_over_base,
-    const float freq_scale,
+    const float corr_factors[4],
     const int64_t i0,
     const float ntk_factor,
-    const float ext_factor,
-    const int n_dims) {
+    const float ext_factor) {
+    float ramp_mix;
+    float theta;
+
+    ramp_mix = rope_ntkv2_ramp(corr_factors[0], corr_factors[1], i0) * ntk_factor;
+    theta = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix;
+
+    ramp_mix = rope_ntkv2_ramp(corr_factors[2], corr_factors[3], i0) * ext_factor;
+    theta = theta * (1 - ramp_mix) + theta_base * ramp_mix;
+    return theta;
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float ggml_rope_ntkv2_corr_factor(const int n_dims, const float n_rot, const float base) {
+    static const float max_pos_emb = 2048;
+    return n_dims * logf(max_pos_emb / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
+}
+
+void ggml_rope_ntkv2_corr_factors(int n_dims, const float freq_base, float factors[4]) {
     // Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
     // Do not change unless there is a good reason for doing so!
-    static const float BETA_0 = 1.75f;
-    static const float BETA_1 = 1.25f;
+    static const float BETA_0  = 1.75f;
+    static const float BETA_1  = 1.25f;
     static const float GAMMA_0 = 16.0f;
     static const float GAMMA_1 = 2.0f;
 
-    static const float low_1p = NTKV2_CORRECTION_FACTOR(BETA_0);
-    static const float high_1p = NTKV2_CORRECTION_FACTOR(BETA_1);
-    static const float low_2p = NTKV2_CORRECTION_FACTOR(GAMMA_0);
-    static const float high_2p = NTKV2_CORRECTION_FACTOR(GAMMA_1);
-
     // start and end correction factors
-    const float low_1 = MAX(0, floorf(low_1p * dims_over_base));
-    const float high_1 = MIN(n_dims - 1, ceilf(high_1p * dims_over_base));
-    const float low_2 = MAX(0, floorf(low_2p * dims_over_base));
-    const float high_2 = MIN(n_dims - 1, ceilf(high_2p * dims_over_base));
-
-    const float theta_linear = freq_scale * theta_base;
-    const float ramp_mix = rope_ntkv2_ramp(low_1, high_1, i0) * ntk_factor;
-    const float theta_mix = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix;
-    const float ramp_final = rope_ntkv2_ramp(low_2, high_2, i0) * ext_factor;
-    return theta_mix * (1 - ramp_final) + theta_base * ramp_final;
+    factors[0] = MAX(0, floorf(ggml_rope_ntkv2_corr_factor(n_dims, BETA_0, freq_base)));
+    factors[1] = MIN(n_dims - 1, ceilf(ggml_rope_ntkv2_corr_factor(n_dims, BETA_1, freq_base)));
+    factors[2] = MAX(0, floorf(ggml_rope_ntkv2_corr_factor(n_dims, GAMMA_0, freq_base)));
+    factors[3] = MIN(n_dims - 1, ceilf(ggml_rope_ntkv2_corr_factor(n_dims, GAMMA_1, freq_base)));
 }
 
 static void ggml_compute_forward_rope_f32(
@@ -12110,7 +12113,8 @@ static void ggml_compute_forward_rope_f32(
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
     const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims);
-    const float dims_over_base = n_dims / logf(freq_base);
+    float corr_factors[4];
+    ggml_rope_ntkv2_corr_factors(n_dims, freq_base, corr_factors);
 
     const bool is_neox = mode & 2;
     const bool is_glm = mode & 4;
@@ -12152,8 +12156,9 @@ static void ggml_compute_forward_rope_f32(
                     }
                 } else if (!is_neox) {
                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-                        const float theta = rope_ntkv2(theta_base, theta_ntk, dims_over_base,
-                            freq_scale, i0, ntk_factor, ext_factor, n_dims);
+                        const float theta_linear = freq_scale * theta_base;
+                        const float theta = rope_ntkv2(theta_base, theta_linear, theta_ntk, corr_factors,
+                            i0, ntk_factor, ext_factor);
                         const float cos_theta = cosf(theta);
                         const float sin_theta = sinf(theta);
 
@@ -12250,7 +12255,8 @@ static void ggml_compute_forward_rope_f16(
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
     const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims);
-    const float dims_over_base = n_dims / logf(freq_base);
+    float corr_factors[4];
+    ggml_rope_ntkv2_corr_factors(n_dims, freq_base, corr_factors);
 
     const bool is_neox = mode & 2;
     const bool is_glm = mode & 4;
@@ -12292,8 +12298,9 @@ static void ggml_compute_forward_rope_f16(
                     }
                 } if (!is_neox) {
                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-                        const float theta = rope_ntkv2(theta_base, theta_ntk, dims_over_base,
-                            freq_scale, i0, ntk_factor, ext_factor, n_dims);
+                        const float theta_linear = freq_scale * theta_base;
+                        const float theta = rope_ntkv2(theta_base, theta_linear, theta_ntk, corr_factors,
+                            i0, ntk_factor, ext_factor);
                         const float cos_theta = cosf(theta);
                         const float sin_theta = sinf(theta);
 
diff --git a/ggml.h b/ggml.h
index 459d217df8068..c2c6b7b1d376c 100644
--- a/ggml.h
+++ b/ggml.h
@@ -1211,6 +1211,9 @@ extern "C" {
             float                 ntk_factor,
             float                 ext_factor);
 
+    // compute correction factors for NTKv2 RoPE scaling
+    GGML_API void ggml_rope_ntkv2_corr_factors(int n_dims, const float freq_base, float factors[4]);
+
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
     GGML_API struct ggml_tensor * ggml_rope_back(