Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ggml-cuda : update rope implementation for parallel decoding #3254

Merged
merged 5 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 46 additions & 14 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,7 @@ static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullpt
struct ggml_tensor_extra_gpu {
void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs
bool copied;
};

// this is faster on Windows
Expand Down Expand Up @@ -4355,7 +4356,15 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
}

// rope == RoPE == rotary positional embedding
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0,
static __global__ void compute_rope_p0(const int32_t * pos, float * p0, int n, int mode, float freq_scale) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) {
int p = pos[i];
p0[i] = (((mode & 1) == 0 ? p : 0)) * freq_scale;
}
}

static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float * p0,
const float p_delta, const int p_delta_rows, const float theta_scale) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

Expand All @@ -4365,8 +4374,9 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c

const int row = blockDim.x*blockIdx.x + threadIdx.x;
const int i = row*ncols + col;
const int i2 = row/p_delta_rows;

const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
const float theta = (p0[i2] + p_delta*i2)*powf(theta_scale, col/2);
const float sin_theta = sinf(theta);
const float cos_theta = cosf(theta);

Expand All @@ -4377,7 +4387,7 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
dst[i + 1] = x0*sin_theta + x1*cos_theta;
}

static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0,
static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float * p0,
const float p_delta, const int p_delta_rows, const float theta_scale) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

Expand All @@ -4387,8 +4397,9 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco

const int row = blockDim.x*blockIdx.x + threadIdx.x;
const int i = row*ncols + col/2;
const int i2 = row/p_delta_rows;

const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
const float theta = (p0[i2] + p_delta*i2)*powf(theta_scale, col/2);
const float sin_theta = sinf(theta);
const float cos_theta = cosf(theta);

Expand All @@ -4399,7 +4410,7 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
}

static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p0,
static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float * p0,
const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx) {
const int col = blockDim.x*blockIdx.x + threadIdx.x;
const int half_n_dims = ncols/4;
Expand All @@ -4410,9 +4421,10 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol

const int row = blockDim.y*blockIdx.y + threadIdx.y;
const int i = row*ncols + col;
const int i2 = row/p_delta_rows;

const float col_theta_scale = powf(theta_scale, col);
const float p = p0 + p_delta*(row/p_delta_rows);
const float p = p0[i2] + p_delta*i2;

const float theta = min(p, p_delta*(n_ctx - 2))*col_theta_scale;
const float sin_theta = sinf(theta);
Expand Down Expand Up @@ -5361,7 +5373,7 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
}

static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0,
const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
GGML_ASSERT(ncols % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
Expand All @@ -5370,7 +5382,7 @@ static void rope_f32_cuda(const float * x, float * dst, const int ncols, const i
rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
}

static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0,
const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
GGML_ASSERT(ncols % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
Expand All @@ -5379,7 +5391,7 @@ static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, co
rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
}

static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0,
const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
GGML_ASSERT(ncols % 4 == 0);
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
Expand Down Expand Up @@ -6069,9 +6081,10 @@ inline void ggml_cuda_op_rope(

const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne2 = dst->ne[2];
const int64_t nrows = ggml_nrows(src0);

const int n_past = ((int32_t *) dst->op_params)[0];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx = ((int32_t *) dst->op_params)[3];
Expand All @@ -6082,21 +6095,40 @@ inline void ggml_cuda_op_rope(
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));

const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
//const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;

GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(src1->ne[0] == ne2);
GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);

int id;
CUDA_CHECK(cudaGetDevice(&id));

struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
if (!src1_extra->copied) {
CUDA_CHECK(cudaMemcpyAsync(src1_extra->data_device[id], src1->data, ggml_nbytes(src1), cudaMemcpyHostToDevice, main_stream));
src1_extra->copied = true;
}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not very good, is there a better way to sync the tensor to VRAM?


size_t p0d_as = 0;
float * p0d = (float *) ggml_cuda_pool_malloc(ne2 * sizeof(float), &p0d_as);
compute_rope_p0<<<(ne2 + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE, CUDA_ROPE_BLOCK_SIZE, 0, main_stream>>>((int32_t*)src1_extra->data_device[id], p0d, ne2, mode, freq_scale);

const bool is_neox = mode & 2;
const bool is_glm = mode & 4;

// compute
if (is_glm) {
rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, n_ctx, main_stream);
rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, freq_scale, ne01, theta_scale, n_ctx, main_stream);
} else if (is_neox) {
GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream);
rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, freq_scale, ne01, theta_scale, main_stream);
} else {
rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream);
rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, freq_scale, ne01, theta_scale, main_stream);
}

ggml_cuda_pool_free(p0d, p0d_as);

(void) src1;
(void) dst;
(void) src1_dd;
Expand Down
6 changes: 6 additions & 0 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2705,6 +2705,7 @@ static struct ggml_cgraph * llm_build_llama(

// KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
offload_func_kq(KQ_pos);
ggml_allocr_alloc(lctx.alloc, KQ_pos);
if (!ggml_allocr_is_measure(lctx.alloc)) {
int * data = (int *) KQ_pos->data;
Expand All @@ -2715,6 +2716,7 @@ static struct ggml_cgraph * llm_build_llama(

// K_shift
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
offload_func_kq(K_shift);
ggml_allocr_alloc(lctx.alloc, K_shift);
if (!ggml_allocr_is_measure(lctx.alloc)) {
int * data = (int *) K_shift->data;
Expand Down Expand Up @@ -3087,6 +3089,7 @@ static struct ggml_cgraph * llm_build_baichaun(

// KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
offload_func_kq(KQ_pos);
ggml_allocr_alloc(lctx.alloc, KQ_pos);
if (!ggml_allocr_is_measure(lctx.alloc)) {
int * data = (int *) KQ_pos->data;
Expand All @@ -3097,6 +3100,7 @@ static struct ggml_cgraph * llm_build_baichaun(

// K_shift
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
offload_func_kq(K_shift);
ggml_allocr_alloc(lctx.alloc, K_shift);
if (!ggml_allocr_is_measure(lctx.alloc)) {
int * data = (int *) K_shift->data;
Expand Down Expand Up @@ -3486,6 +3490,7 @@ static struct ggml_cgraph * llm_build_falcon(

// KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
offload_func_kq(KQ_pos);
ggml_allocr_alloc(lctx.alloc, KQ_pos);
if (!ggml_allocr_is_measure(lctx.alloc)) {
int * data = (int *) KQ_pos->data;
Expand All @@ -3496,6 +3501,7 @@ static struct ggml_cgraph * llm_build_falcon(

// K_shift
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
offload_func_kq(K_shift);
ggml_allocr_alloc(lctx.alloc, K_shift);
if (!ggml_allocr_is_measure(lctx.alloc)) {
int * data = (int *) K_shift->data;
Expand Down