Skip to content

Commit

Permalink
ggml-cuda : update rope implementation for parallel decoding (#3254)
Browse files Browse the repository at this point in the history
* ggml-cuda : update rope implementation for parallel decoding

* better solution for p0 computation

* fix rope

* simpler rope implementation

---------

Co-authored-by: Georgi Gerganov <[email protected]>
  • Loading branch information
slaren and ggerganov authored Sep 19, 2023
1 parent daf4c6d commit 7e2b997
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 25 deletions.
79 changes: 54 additions & 25 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,7 @@ static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullpt
struct ggml_tensor_extra_gpu {
void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs
bool copied;
};

// this is faster on Windows
Expand Down Expand Up @@ -4355,8 +4356,9 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
}

// rope == RoPE == rotary positional embedding
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0,
const float p_delta, const int p_delta_rows, const float theta_scale) {

static __global__ void rope_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
const int p_delta_rows, const float theta_scale) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

if (col >= ncols) {
Expand All @@ -4365,8 +4367,11 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c

const int row = blockDim.x*blockIdx.x + threadIdx.x;
const int i = row*ncols + col;
const int i2 = row/p_delta_rows;

const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
const int p = pos != nullptr ? pos[i2] : 0;
const float p0 = p * freq_scale;
const float theta = p0*powf(theta_scale, col/2);
const float sin_theta = sinf(theta);
const float cos_theta = cosf(theta);

Expand All @@ -4377,8 +4382,8 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
dst[i + 1] = x0*sin_theta + x1*cos_theta;
}

static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0,
const float p_delta, const int p_delta_rows, const float theta_scale) {
static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
const int p_delta_rows, const float theta_scale) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

if (col >= ncols) {
Expand All @@ -4387,8 +4392,11 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco

const int row = blockDim.x*blockIdx.x + threadIdx.x;
const int i = row*ncols + col/2;
const int i2 = row/p_delta_rows;

const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
const int p = pos != nullptr ? pos[i2] : 0;
const float p0 = p * freq_scale;
const float theta = p0*powf(theta_scale, col/2);
const float sin_theta = sinf(theta);
const float cos_theta = cosf(theta);

Expand All @@ -4399,8 +4407,8 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
}

static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p0,
const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx) {
static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
const int p_delta_rows, const float theta_scale, const int n_ctx) {
const int col = blockDim.x*blockIdx.x + threadIdx.x;
const int half_n_dims = ncols/4;

Expand All @@ -4410,11 +4418,13 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol

const int row = blockDim.y*blockIdx.y + threadIdx.y;
const int i = row*ncols + col;
const int i2 = row/p_delta_rows;

const float col_theta_scale = powf(theta_scale, col);
const float p = p0 + p_delta*(row/p_delta_rows);
// FIXME: this is likely wrong
const int p = pos != nullptr ? pos[i2] : 0;

const float theta = min(p, p_delta*(n_ctx - 2))*col_theta_scale;
const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale;
const float sin_theta = sinf(theta);
const float cos_theta = cosf(theta);

Expand All @@ -4424,7 +4434,7 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
dst[i + 0] = x0*cos_theta - x1*sin_theta;
dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;

const float block_theta = max(p - p_delta*(n_ctx - 2), 0.f)*col_theta_scale;
const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale;
const float sin_block_theta = sinf(block_theta);
const float cos_block_theta = cosf(block_theta);

Expand Down Expand Up @@ -5361,31 +5371,31 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
}

static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
GGML_ASSERT(ncols % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nrows, num_blocks_x, 1);
rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
}

static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
GGML_ASSERT(ncols % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nrows, num_blocks_x, 1);
rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
}

static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
GGML_ASSERT(ncols % 4 == 0);
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
const dim3 block_nums(num_blocks_x, nrows, 1);
rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale, n_ctx);
rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale, n_ctx);
}

static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
Expand Down Expand Up @@ -6069,9 +6079,10 @@ inline void ggml_cuda_op_rope(

const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne2 = dst->ne[2];
const int64_t nrows = ggml_nrows(src0);

const int n_past = ((int32_t *) dst->op_params)[0];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx = ((int32_t *) dst->op_params)[3];
Expand All @@ -6082,19 +6093,37 @@ inline void ggml_cuda_op_rope(
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));

const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
// const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;

GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(src1->ne[0] == ne2);
GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);

int id;
CUDA_CHECK(cudaGetDevice(&id));

int * pos = nullptr;
if ((mode & 1) == 0) {
struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
pos = (int *) src1_extra->data_device[id];
if (!src1_extra->copied) {
CUDA_CHECK(cudaMemcpyAsync(pos, src1->data, ggml_nbytes(src1), cudaMemcpyHostToDevice, main_stream));
src1_extra->copied = true;
}
}

const bool is_neox = mode & 2;
const bool is_glm = mode & 4;

// compute
if (is_glm) {
rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, n_ctx, main_stream);
GGML_ASSERT(false);
rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream);
} else if (is_neox) {
GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream);
rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
} else {
rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream);
rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
}

(void) src1;
Expand Down
6 changes: 6 additions & 0 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2708,6 +2708,7 @@ static struct ggml_cgraph * llm_build_llama(

// KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
offload_func_kq(KQ_pos);
ggml_allocr_alloc(lctx.alloc, KQ_pos);
if (!ggml_allocr_is_measure(lctx.alloc)) {
int * data = (int *) KQ_pos->data;
Expand All @@ -2719,6 +2720,7 @@ static struct ggml_cgraph * llm_build_llama(
// shift the entire K-cache if needed
if (do_rope_shift) {
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
offload_func_kq(K_shift);
ggml_allocr_alloc(lctx.alloc, K_shift);
if (!ggml_allocr_is_measure(lctx.alloc)) {
int * data = (int *) K_shift->data;
Expand Down Expand Up @@ -3092,6 +3094,7 @@ static struct ggml_cgraph * llm_build_baichaun(

// KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
offload_func_kq(KQ_pos);
ggml_allocr_alloc(lctx.alloc, KQ_pos);
if (!ggml_allocr_is_measure(lctx.alloc)) {
int * data = (int *) KQ_pos->data;
Expand All @@ -3103,6 +3106,7 @@ static struct ggml_cgraph * llm_build_baichaun(
// shift the entire K-cache if needed
if (do_rope_shift) {
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
offload_func_kq(K_shift);
ggml_allocr_alloc(lctx.alloc, K_shift);
if (!ggml_allocr_is_measure(lctx.alloc)) {
int * data = (int *) K_shift->data;
Expand Down Expand Up @@ -3496,6 +3500,7 @@ static struct ggml_cgraph * llm_build_falcon(

// KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
offload_func_kq(KQ_pos);
ggml_allocr_alloc(lctx.alloc, KQ_pos);
if (!ggml_allocr_is_measure(lctx.alloc)) {
int * data = (int *) KQ_pos->data;
Expand All @@ -3507,6 +3512,7 @@ static struct ggml_cgraph * llm_build_falcon(
// shift the entire K-cache if needed
if (do_rope_shift) {
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
offload_func_kq(K_shift);
ggml_allocr_alloc(lctx.alloc, K_shift);
if (!ggml_allocr_is_measure(lctx.alloc)) {
int * data = (int *) K_shift->data;
Expand Down

0 comments on commit 7e2b997

Please sign in to comment.