llama: improve NTKv2 CUDA implementation
Precompute what we can on the host to make the device kernel smaller,
and to avoid magic constants.
cebtenzzre committed Jul 19, 2023
1 parent 2a9ba48 commit f3b9eae
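
The shape of the change, in brief: theta_ntk_scale and the four ramp correction factors are uniform across an entire RoPE launch, so they are now computed once on the host in ggml_cuda_op_rope and passed down, rather than rebuilt from hard-coded constants by every CUDA thread. One wrinkle to be aware of: a `const float corr_factors[4]` kernel parameter decays to a plain pointer, so the launch appears to hand the device the address of a host stack array, which a discrete GPU generally cannot dereference; packing the four floats into a small struct passed by value avoids this. The following is a minimal sketch of the host-precompute pattern under that struct-by-value assumption, with illustrative names (corr_dims, ramp_kernel, corr) that are not the commit's:

// Sketch only: hoist uniform per-launch work out of the kernel.
#include <cmath>
#include <cstdio>

struct corr_dims { float v[4]; };  // passed by value, so no host pointer reaches the device

__global__ void ramp_kernel(float * out, int n, corr_dims cd) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= n) return;
    // cd.v[] arrives precomputed: no per-thread logf/powf and no magic constants.
    // fmaxf guards the degenerate high == low case (see the note on ntkv2_ramp below).
    const float y = (i / 2 - cd.v[0]) / fmaxf(0.001f, cd.v[1] - cd.v[0]);
    out[i] = 1.0f - fminf(1.0f, fmaxf(0.0f, y));
}

int main() {
    const int n_dims = 128;
    const float base = 10000.0f, max_pos_emb = 2048.0f;
    // Host side: derive the correction dims once per launch, not once per thread.
    auto corr = [&](float n_rot) {
        return n_dims * logf(max_pos_emb / (n_rot * 2.0f * (float)M_PI)) / (2.0f * logf(base));
    };
    const corr_dims cd = {{ floorf(corr(1.75f)), ceilf(corr(1.25f)),
                            floorf(corr(16.0f)), ceilf(corr(2.0f)) }};

    float * d_out;
    cudaMalloc(&d_out, n_dims*sizeof(float));
    ramp_kernel<<<1, n_dims>>>(d_out, n_dims, cd);

    float out[128];
    cudaMemcpy(out, d_out, sizeof(out), cudaMemcpyDeviceToHost);
    printf("%.3f %.3f %.3f\n", out[0], out[76], out[96]);  // 1.000 0.333 0.000
    cudaFree(d_out);
    return 0;
}
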
Showing 2 changed files with 63 additions and 52 deletions.
94 changes: 50 additions & 44 deletions ggml-cuda.cu
@@ -1875,52 +1875,36 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
 
-static __device__ void ntkv2_ramp(const float low, const float high, const int i0, float *out) {
+static __device__ float ntkv2_ramp(const float low, const float high, const int i0) {
     const float y = (i0 / 2 - low) / min(0.001f, high - low);
-    *out = 1.0f - min(1.0f, max(0.0f, y));
+    return 1.0f - min(1.0f, max(0.0f, y));
 }
 
 // NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope
 // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
-static __device__ void compute_ntkv2(
+static __device__ float compute_ntkv2(
         float theta_base,
+        float theta_linear,
         float theta_ntk,
-        float dims_over_base,
-        float freq_scale,
+        const float corr_factors[4],
         int64_t i0,
         float ntk_factor,
-        float extrapolation_factor,
-        int n_dims,
-        float *theta) {
-    // Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
-    // Do not change unless there is a good reason for doing so!
-    // These are precomputed because CUDA doesn't allow dynamic init of device constants
-    static const float low_1p = 2.6135630f;
-    static const float high_1p = 2.7817991f;
-    static const float low_2p = 1.5070765f;
-    static const float high_2p = 2.5467973f;
-
-    // start and end correction factors
-    const float low_1 = max(0.0f, floorf(low_1p * dims_over_base));
-    const float high_1 = min(n_dims - 1.0f, ceilf(high_1p * dims_over_base));
-    const float low_2 = max(0.0f, floorf(low_2p * dims_over_base));
-    const float high_2 = min(n_dims - 1.0f, ceilf(high_2p * dims_over_base));
-
+        float extrapolation_factor) {
     float ramp_mix;
+    float theta;
 
-    const float theta_linear = freq_scale * theta_base;
-    ntkv2_ramp(low_1, high_1, i0, &ramp_mix);
-    ramp_mix *= ntk_factor;
-    const float theta_mix = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix;
-    ntkv2_ramp(low_2, high_2, i0, &ramp_mix);
-    ramp_mix *= extrapolation_factor;
-    *theta = theta_mix * (1 - ramp_mix) + theta_base * ramp_mix;
+    ramp_mix = ntkv2_ramp(corr_factors[0], corr_factors[1], i0) * ntk_factor;
+    theta = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix;
+
+    ramp_mix = ntkv2_ramp(corr_factors[2], corr_factors[3], i0) * extrapolation_factor;
+    theta = theta * (1 - ramp_mix) + theta_base * ramp_mix;
+    return theta;
 }
 
 // rope == RoPE == rotary positional embedding
-static __global__ void rope_f32(const float * x, float * dst, const int ncols, const int n_dims, const float freq_base,
+static __global__ void rope_f32(const float * x, float * dst, const int ncols,
     const float freq_scale, const float ntk_factor, const float extrapolation_factor, const float theta_scale,
-    const float theta_ntk_scale, const float dims_over_base, const float p) {
+    const float theta_ntk_scale, const float p, const float corr_factors[4]) {
 
     const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);

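For reference, the rewritten device function is two chained linear blends. Writing r1 = ntkv2_ramp(corr_factors[0], corr_factors[1], i0) * ntk_factor and r2 = ntkv2_ramp(corr_factors[2], corr_factors[3], i0) * extrapolation_factor, the returned angle is

    theta = (1 - r2) * ((1 - r1) * theta_linear + r1 * theta_ntk) + r2 * theta_base

Because ntkv2_ramp evaluates to 1 below its window and 0 above it, the lowest (fastest-rotating) dimensions recover pure theta_base extrapolation when extrapolation_factor is 1, while dimensions beyond the windows get the linear/NTK interpolation mix. Note one pre-existing oddity carried over unchanged: ntkv2_ramp divides by min(0.001f, high - low), which caps the denominator at 0.001 and collapses the ramp into a step near i0/2 = low; the Python reference only nudges the denominator in the degenerate high == low case, so max(0.001f, high - low) appears to be what was intended.
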
@@ -1931,11 +1915,11 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     const int row = blockDim.y*blockIdx.y + threadIdx.y;
     const int i = row*ncols + col;
 
-    const float theta_base = p*powf(theta_scale, col/2);
-    const float theta_ntk = p*powf(theta_ntk_scale, col/2);
-    float theta;
-    compute_ntkv2(theta_base, theta_ntk, dims_over_base,
-        freq_scale, col, ntk_factor, extrapolation_factor, n_dims, &theta);
+    const float theta_base = p*powf(theta_scale, col/2);
+    const float theta_linear = freq_scale * theta_base;
+    const float theta_ntk = p*powf(theta_ntk_scale, col/2);
+    const float theta = compute_ntkv2(theta_base, theta_linear, theta_ntk, corr_factors, col, ntk_factor,
+        extrapolation_factor);
     const float sin_theta = sinf(theta);
     const float cos_theta = cosf(theta);
 
@@ -2415,16 +2399,16 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
 }
 
 static void rope_f32_cuda(
-    const float * x, float * dst, const int ncols, const int nrows, const int n_dims, const float freq_base,
+    const float * x, float * dst, const int ncols, const int nrows,
     const float freq_scale, const float ntk_factor, const float extrapolation_factor, const float theta_scale,
-    const float theta_ntk_scale, const float dims_over_base, const float p, cudaStream_t stream) {
+    const float theta_ntk_scale, const float p, const float corr_factors[4], cudaStream_t stream) {
 
     GGML_ASSERT(nrows % 2 == 0);
     const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(num_blocks_x, nrows, 1);
-    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, n_dims, freq_base, freq_scale, ntk_factor,
-        extrapolation_factor, theta_scale, theta_ntk_scale, dims_over_base, p);
+    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, freq_scale, ntk_factor,
+        extrapolation_factor, theta_scale, theta_ntk_scale, p, corr_factors);
 }
 
 static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
Expand Down Expand Up @@ -2990,6 +2974,13 @@ inline void ggml_cuda_op_mul_mat_cublas(
     (void) i1;
 }
 
+// Apparently solving `max_pos_emb = n_rot * 2pi * base^((2 * x) / n_dims)` for x, we get
+// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float ntkv2_correction_factor(const int n_dims, const float n_rot, const float base) {
+    static const float max_pos_emb = 2048;
+    return n_dims * logf(max_pos_emb / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
+}
+
 inline void ggml_cuda_op_rope(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
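
A quick check that the new helper reproduces the constants it replaces: ntkv2_correction_factor(n_dims, n_rot, base) factors as (n_dims / log(base)) * (log(max_pos_emb / (2*pi*n_rot)) / 2), i.e. the old dims_over_base times a coefficient that depends only on n_rot. Evaluating that coefficient at the four thresholds gives log(2048/(2*pi*1.75))/2 ≈ 2.61356, log(2048/(2*pi*1.25))/2 ≈ 2.78180, log(2048/(2*pi*16))/2 ≈ 1.50708 and log(2048/(2*pi*2))/2 ≈ 2.54680, which match the deleted low_1p/high_1p/low_2p/high_2p magic numbers to the printed precision. For LLaMA-style n_dims = 128 and base = 10000, the clamped corr_factors come out to {36, 39, 20, 36}.
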
@@ -3016,8 +3007,6 @@ inline void ggml_cuda_op_rope(
     memcpy(&extrapolation_factor, (int32_t *) src1->data + 7, sizeof(float));
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims);
-    const float dims_over_base = n_dims / logf(freq_base);
     const float p = ((mode & 1) == 0 ? n_past + i02 : i02);
 
     bool is_glm = mode & 4;
@@ -3028,8 +3017,25 @@ inline void ggml_cuda_op_rope(
         const float block_p = max(p - (n_ctx - 2.f), 0.f);
         rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
     } else {
-        rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, n_dims, freq_base, freq_scale, ntk_factor,
-            extrapolation_factor, theta_scale, theta_ntk_scale, dims_over_base, p, cudaStream_main);
+        const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims);
+
+        // Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
+        // Do not change unless there is a good reason for doing so!
+        static const float BETA_0 = 1.75f;
+        static const float BETA_1 = 1.25f;
+        static const float GAMMA_0 = 16.0f;
+        static const float GAMMA_1 = 2.0f;
+
+        // start and end correction factors
+        const float corr_factors[4] = {
+            max(0.0f, floorf(ntkv2_correction_factor(n_dims, BETA_0, freq_base))),
+            min(n_dims - 1.0f, ceilf(ntkv2_correction_factor(n_dims, BETA_1, freq_base))),
+            max(0.0f, floorf(ntkv2_correction_factor(n_dims, GAMMA_0, freq_base))),
+            min(n_dims - 1.0f, ceilf(ntkv2_correction_factor(n_dims, GAMMA_1, freq_base))),
+        };
+
+        rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, freq_scale, ntk_factor,
+            extrapolation_factor, theta_scale, theta_ntk_scale, p, corr_factors, cudaStream_main);
     }
 
     (void) dst;
21 changes: 13 additions & 8 deletions ggml.c
@@ -12129,16 +12129,21 @@ static float compute_ntkv2(
     static const float high_2p = NTKV2_CORRECTION_FACTOR(GAMMA_1);
 
     // start and end correction factors
-    const float low_1 = maxf(0, floorf(low_1p * dims_over_base));
-    const float high_1 = minf(n_dims - 1, ceilf(high_1p * dims_over_base));
-    const float low_2 = maxf(0, floorf(low_2p * dims_over_base));
-    const float high_2 = minf(n_dims - 1, ceilf(high_2p * dims_over_base));
+    const float low_1 = maxf(0.0f, floorf(low_1p * dims_over_base));
+    const float high_1 = minf(n_dims - 1.0f, ceilf(high_1p * dims_over_base));
+    const float low_2 = maxf(0.0f, floorf(low_2p * dims_over_base));
+    const float high_2 = minf(n_dims - 1.0f, ceilf(high_2p * dims_over_base));
 
     const float theta_linear = freq_scale * theta_base;
-    const float ramp_mix = ntkv2_ramp(low_1, high_1, i0) * ntk_factor;
-    const float theta_mix = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix;
-    const float ramp_final = ntkv2_ramp(low_2, high_2, i0) * extrapolation_factor;
-    return theta_mix * (1 - ramp_final) + theta_base * ramp_final;
+    float ramp_mix;
+    float theta;
+
+    ramp_mix = ntkv2_ramp(low_1, high_1, i0) * ntk_factor;
+    theta = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix;
+
+    ramp_mix = ntkv2_ramp(low_2, high_2, i0) * extrapolation_factor;
+    theta = theta * (1 - ramp_mix) + theta_base * ramp_mix;
+    return theta;
 }
 
 static void ggml_compute_forward_rope_f32(
