Skip to content

Commit

Permalink
llama: reduce code duplication in NTKv2 RoPE
Browse files Browse the repository at this point in the history
  • Loading branch information
cebtenzzre committed Jul 19, 2023
1 parent 03a715f commit f30c571
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 64 deletions.
34 changes: 7 additions & 27 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1875,14 +1875,14 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
cpy_1(cx + x_offset, cdst + dst_offset);
}

// Linear blending ramp over the rotary dimension index: returns 1 for
// i0/2 <= low, 0 for i0/2 >= high, and interpolates linearly in between.
static __device__ float ggml_rope_ntkv2_ramp(const float low, const float high, const int i0) {
    // max() keeps the divisor from reaching zero when high == low.  The previous
    // min(0.001f, high - low) clamped the divisor to <= 0.001, which turned the
    // intended linear ramp into a near-step function (cf. the later YaRN ramp).
    const float y = (i0 / 2 - low) / max(0.001f, high - low);
    return 1.0f - min(1.0f, max(0.0f, y));
}

// NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
//
// Blends three candidate rotation angles for dimension index i0:
//   theta_base   - unscaled (extrapolation) angle
//   theta_linear - linearly interpolated angle (freq_scale applied)
//   theta_ntk    - NTK-aware scaled angle
// corr_factors[0..1] bound the NTK-mixing ramp, corr_factors[2..3] bound the
// extrapolation ramp.
// NOTE(review): the middle of the parameter list was truncated in the diff view;
// reconstructed from the call site in rope_f32 — confirm against upstream.
static __device__ float ggml_rope_ntkv2(
        const float theta_base,
        const float theta_linear,
        const float theta_ntk,
        const float corr_factors[4],
        const int i0,
        const float ntk_factor,
        const float extrapolation_factor) {
    // first ramp: mix the linear-scaled angle with the NTK-scaled angle
    float ramp_mix = ggml_rope_ntkv2_ramp(corr_factors[0], corr_factors[1], i0) * ntk_factor;
    float theta    = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix;

    // second ramp: mix the result back toward the unscaled (extrapolated) angle
    ramp_mix = ggml_rope_ntkv2_ramp(corr_factors[2], corr_factors[3], i0) * extrapolation_factor;
    theta    = theta * (1 - ramp_mix) + theta_base * ramp_mix;
    return theta;
}
Expand All @@ -1918,7 +1918,7 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols,
const float theta_base = p*powf(theta_scale, col/2);
const float theta_linear = freq_scale * theta_base;
const float theta_ntk = p*powf(theta_ntk_scale, col/2);
const float theta = compute_ntkv2(theta_base, theta_linear, theta_ntk, corr_factors, col, ntk_factor,
const float theta = ggml_rope_ntkv2(theta_base, theta_linear, theta_ntk, corr_factors, col, ntk_factor,
extrapolation_factor);
const float sin_theta = sinf(theta);
const float cos_theta = cosf(theta);
Expand Down Expand Up @@ -2974,13 +2974,6 @@ inline void ggml_cuda_op_mul_mat_cublas(
(void) i1;
}

// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
static float ntkv2_correction_factor(const int n_dims, const float n_rot, const float base) {
    // context length the derivation above assumes for the pretrained model
    static const float max_pos_emb = 2048;
    const float numerator   = logf(max_pos_emb / (n_rot * 2 * (float)M_PI));
    const float denominator = 2 * logf(base);
    return n_dims * numerator / denominator;
}

inline void ggml_cuda_op_rope(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
Expand Down Expand Up @@ -3018,21 +3011,8 @@ inline void ggml_cuda_op_rope(
rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
} else {
const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims);

// Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
// Do not change unless there is a good reason for doing so!
static const float BETA_0 = 1.75f;
static const float BETA_1 = 1.25f;
static const float GAMMA_0 = 16.0f;
static const float GAMMA_1 = 2.0f;

// start and end correction factors
const float corr_factors[4] = {
max(0.0f, floorf(ntkv2_correction_factor(n_dims, BETA_0, freq_base))),
min(n_dims - 1.0f, ceilf(ntkv2_correction_factor(n_dims, BETA_1, freq_base))),
max(0.0f, floorf(ntkv2_correction_factor(n_dims, GAMMA_0, freq_base))),
min(n_dims - 1.0f, ceilf(ntkv2_correction_factor(n_dims, GAMMA_1, freq_base))),
};
float corr_factors[4];
ggml_rope_ntkv2_corr_factors(n_dims, freq_base, corr_factors);

rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, freq_scale, ntk_factor,
extrapolation_factor, theta_scale, theta_ntk_scale, p, corr_factors, cudaStream_main);
Expand Down
75 changes: 38 additions & 37 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -12093,57 +12093,54 @@ static void ggml_compute_forward_clamp(

// ggml_compute_forward_rope

// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
#define NTKV2_MAX_POS_EMB 2048
#define NTKV2_CORRECTION_FACTOR(n_rot) (__builtin_logf(NTKV2_MAX_POS_EMB / ((n_rot) * 2 * (float)M_PI)) / 2)

// Linear blending ramp over the rotary dimension index: returns 1 for
// i0/2 <= low, 0 for i0/2 >= high, and interpolates linearly in between.
// use -ffast-math so fminf/fmaxf are optimized to vminss and vmaxss
__attribute__((optimize("-ffast-math"), always_inline))
static inline float ggml_rope_ntkv2_ramp(const float low, const float high, const int i0) {
    // fmaxf keeps the divisor from reaching zero when high == low.  The previous
    // MIN(0.001f, high - low) clamped the divisor to <= 0.001f, which turned the
    // intended linear ramp into a near-step function.
    const float y = (i0 / 2 - low) / fmaxf(0.001f, high - low);
    return 1.0f - fminf(1.0f, fmaxf(0.0f, y));
}

// NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
//
// Blends three candidate rotation angles for dimension index i0:
//   theta_base   - unscaled (extrapolation) angle
//   theta_linear - linearly interpolated angle (freq_scale applied by the caller)
//   theta_ntk    - NTK-aware scaled angle
// corr_factors[0..1] bound the NTK-mixing ramp, corr_factors[2..3] bound the
// extrapolation ramp (as produced by ggml_rope_ntkv2_corr_factors).
static float ggml_rope_ntkv2(
        float theta_base,
        float theta_linear,
        float theta_ntk,
        const float corr_factors[4],
        int64_t i0,
        float ntk_factor,
        float extrapolation_factor) {
    float ramp_mix;
    float theta;

    // first ramp: mix the linear-scaled angle with the NTK-scaled angle
    ramp_mix = ggml_rope_ntkv2_ramp(corr_factors[0], corr_factors[1], i0) * ntk_factor;
    theta = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix;

    // second ramp: mix the result back toward the unscaled (extrapolated) angle
    ramp_mix = ggml_rope_ntkv2_ramp(corr_factors[2], corr_factors[3], i0) * extrapolation_factor;
    theta = theta * (1 - ramp_mix) + theta_base * ramp_mix;
    return theta;
}

// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
static float ggml_rope_ntkv2_corr_factor(const int n_dims, const float n_rot, const float base) {
    // context length the derivation above assumes for the pretrained model
    static const float max_pos_emb = 2048;
    const float numerator   = logf(max_pos_emb / (n_rot * 2 * (float)M_PI));
    const float denominator = 2 * logf(base);
    return n_dims * numerator / denominator;
}

// Computes the NTKv2 ramp boundaries ("correction factors") for a model with
// n_dims rotary dimensions and frequency base freq_base.  On return,
// factors[0..1] bound the NTK-mixing ramp and factors[2..3] bound the
// extrapolation ramp; all four are clamped to [0, n_dims - 1].
void ggml_rope_ntkv2_corr_factors(int n_dims, const float freq_base, float factors[4]) {
    // Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
    // Do not change unless there is a good reason for doing so!
    static const float BETA_0  = 1.75f;
    static const float BETA_1  = 1.25f;
    static const float GAMMA_0 = 16.0f;
    static const float GAMMA_1 = 2.0f;

    // start and end correction factors, clamped to the valid dimension range
    factors[0] = maxf(0.0f,          floorf(ggml_rope_ntkv2_corr_factor(n_dims, BETA_0,  freq_base)));
    factors[1] = minf(n_dims - 1.0f, ceilf(ggml_rope_ntkv2_corr_factor(n_dims, BETA_1,  freq_base)));
    factors[2] = maxf(0.0f,          floorf(ggml_rope_ntkv2_corr_factor(n_dims, GAMMA_0, freq_base)));
    factors[3] = minf(n_dims - 1.0f, ceilf(ggml_rope_ntkv2_corr_factor(n_dims, GAMMA_1, freq_base)));
}

static void ggml_compute_forward_rope_f32(
Expand Down Expand Up @@ -12201,7 +12198,8 @@ static void ggml_compute_forward_rope_f32(

const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims);
const float dims_over_base = n_dims / logf(freq_base);
float corr_factors[4];
ggml_rope_ntkv2_corr_factors(n_dims, freq_base, corr_factors);

const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
Expand Down Expand Up @@ -12243,8 +12241,9 @@ static void ggml_compute_forward_rope_f32(
}
} else if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float theta = compute_ntkv2(theta_base, theta_ntk, dims_over_base,
freq_scale, i0, ntk_factor, extrapolation_factor, n_dims);
const float theta_linear = freq_scale * theta_base;
const float theta = ggml_rope_ntkv2(theta_base, theta_linear, theta_ntk, corr_factors,
i0, ntk_factor, extrapolation_factor);
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);

Expand Down Expand Up @@ -12343,7 +12342,8 @@ static void ggml_compute_forward_rope_f16(

const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims);
const float dims_over_base = n_dims / logf(freq_base);
float corr_factors[4];
ggml_rope_ntkv2_corr_factors(n_dims, freq_base, corr_factors);

const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
Expand Down Expand Up @@ -12385,8 +12385,9 @@ static void ggml_compute_forward_rope_f16(
}
} if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float theta = compute_ntkv2(theta_base, theta_ntk, dims_over_base,
freq_scale, i0, ntk_factor, extrapolation_factor, n_dims);
const float theta_linear = freq_scale * theta_base;
const float theta = ggml_rope_ntkv2(theta_base, theta_linear, theta_ntk, corr_factors,
i0, ntk_factor, extrapolation_factor);
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);

Expand Down
3 changes: 3 additions & 0 deletions ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,9 @@ extern "C" {
float extrapolation_factor,
int n_ctx);

// compute correction factors for NTKv2 RoPE scaling
void ggml_rope_ntkv2_corr_factors(int n_dims, const float freq_base, float factors[4]);

// rotary position embedding backward, i.e compute dx from dy
// a - dy
GGML_API struct ggml_tensor * ggml_rope_back(
Expand Down

0 comments on commit f30c571

Please sign in to comment.