-
Notifications
You must be signed in to change notification settings - Fork 11k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ggml-cuda : update rope implementation for parallel decoding #3254
Changes from all commits
eec6b66
fb92acd
cbe2bac
aa18b93
9335276
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -439,6 +439,7 @@ static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullpt | |
struct ggml_tensor_extra_gpu { | ||
void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors | ||
cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs | ||
bool copied; | ||
}; | ||
|
||
// this is faster on Windows | ||
|
@@ -4355,8 +4356,9 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, | |
} | ||
|
||
// rope == RoPE == rotary positional embedding | ||
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0, | ||
const float p_delta, const int p_delta_rows, const float theta_scale) { | ||
|
||
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale, | ||
const int p_delta_rows, const float theta_scale) { | ||
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); | ||
|
||
if (col >= ncols) { | ||
|
@@ -4365,8 +4367,11 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c | |
|
||
const int row = blockDim.x*blockIdx.x + threadIdx.x; | ||
const int i = row*ncols + col; | ||
const int i2 = row/p_delta_rows; | ||
|
||
const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2); | ||
const int p = pos != nullptr ? pos[i2] : 0; | ||
const float p0 = p * freq_scale; | ||
const float theta = p0*powf(theta_scale, col/2); | ||
const float sin_theta = sinf(theta); | ||
const float cos_theta = cosf(theta); | ||
|
||
|
@@ -4377,8 +4382,8 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c | |
dst[i + 1] = x0*sin_theta + x1*cos_theta; | ||
} | ||
|
||
static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0, | ||
const float p_delta, const int p_delta_rows, const float theta_scale) { | ||
static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale, | ||
const int p_delta_rows, const float theta_scale) { | ||
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); | ||
|
||
if (col >= ncols) { | ||
|
@@ -4387,8 +4392,11 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco | |
|
||
const int row = blockDim.x*blockIdx.x + threadIdx.x; | ||
const int i = row*ncols + col/2; | ||
const int i2 = row/p_delta_rows; | ||
|
||
const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2); | ||
const int p = pos != nullptr ? pos[i2] : 0; | ||
const float p0 = p * freq_scale; | ||
const float theta = p0*powf(theta_scale, col/2); | ||
const float sin_theta = sinf(theta); | ||
const float cos_theta = cosf(theta); | ||
|
||
|
@@ -4399,8 +4407,8 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco | |
dst[i + ncols/2] = x0*sin_theta + x1*cos_theta; | ||
} | ||
|
||
static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p0, | ||
const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx) { | ||
static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale, | ||
const int p_delta_rows, const float theta_scale, const int n_ctx) { | ||
const int col = blockDim.x*blockIdx.x + threadIdx.x; | ||
const int half_n_dims = ncols/4; | ||
|
||
|
@@ -4410,11 +4418,13 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol | |
|
||
const int row = blockDim.y*blockIdx.y + threadIdx.y; | ||
const int i = row*ncols + col; | ||
const int i2 = row/p_delta_rows; | ||
|
||
const float col_theta_scale = powf(theta_scale, col); | ||
const float p = p0 + p_delta*(row/p_delta_rows); | ||
// FIXME: this is likely wrong | ||
const int p = pos != nullptr ? pos[i2] : 0; | ||
|
||
const float theta = min(p, p_delta*(n_ctx - 2))*col_theta_scale; | ||
const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale; | ||
const float sin_theta = sinf(theta); | ||
const float cos_theta = cosf(theta); | ||
|
||
|
@@ -4424,7 +4434,7 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol | |
dst[i + 0] = x0*cos_theta - x1*sin_theta; | ||
dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta; | ||
|
||
const float block_theta = max(p - p_delta*(n_ctx - 2), 0.f)*col_theta_scale; | ||
const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale; | ||
const float sin_block_theta = sinf(block_theta); | ||
const float cos_block_theta = cosf(block_theta); | ||
|
||
|
@@ -5361,31 +5371,31 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons | |
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k); | ||
} | ||
|
||
static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0, | ||
const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) { | ||
static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, | ||
const int p_delta_rows, const float theta_scale, cudaStream_t stream) { | ||
GGML_ASSERT(ncols % 2 == 0); | ||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); | ||
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); | ||
const dim3 block_nums(nrows, num_blocks_x, 1); | ||
rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale); | ||
rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); | ||
} | ||
|
||
static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0, | ||
const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) { | ||
static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, | ||
const int p_delta_rows, const float theta_scale, cudaStream_t stream) { | ||
GGML_ASSERT(ncols % 2 == 0); | ||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); | ||
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); | ||
const dim3 block_nums(nrows, num_blocks_x, 1); | ||
rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale); | ||
rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); | ||
} | ||
|
||
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0, | ||
const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) { | ||
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, | ||
const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) { | ||
GGML_ASSERT(ncols % 4 == 0); | ||
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1); | ||
const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE; | ||
const dim3 block_nums(num_blocks_x, nrows, 1); | ||
rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale, n_ctx); | ||
rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale, n_ctx); | ||
} | ||
|
||
static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, | ||
|
@@ -6069,9 +6079,10 @@ inline void ggml_cuda_op_rope( | |
|
||
const int64_t ne00 = src0->ne[0]; | ||
const int64_t ne01 = src0->ne[1]; | ||
const int64_t ne2 = dst->ne[2]; | ||
const int64_t nrows = ggml_nrows(src0); | ||
|
||
const int n_past = ((int32_t *) dst->op_params)[0]; | ||
//const int n_past = ((int32_t *) dst->op_params)[0]; | ||
const int n_dims = ((int32_t *) dst->op_params)[1]; | ||
const int mode = ((int32_t *) dst->op_params)[2]; | ||
const int n_ctx = ((int32_t *) dst->op_params)[3]; | ||
|
@@ -6082,19 +6093,37 @@ inline void ggml_cuda_op_rope( | |
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); | ||
|
||
const float theta_scale = powf(freq_base, -2.0f/n_dims); | ||
const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale; | ||
// const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale; | ||
|
||
GGML_ASSERT(src1->type == GGML_TYPE_I32); | ||
GGML_ASSERT(src1->ne[0] == ne2); | ||
GGML_ASSERT(src1->backend == GGML_BACKEND_GPU); | ||
|
||
int id; | ||
CUDA_CHECK(cudaGetDevice(&id)); | ||
|
||
int * pos = nullptr; | ||
if ((mode & 1) == 0) { | ||
struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; | ||
pos = (int *) src1_extra->data_device[id]; | ||
if (!src1_extra->copied) { | ||
CUDA_CHECK(cudaMemcpyAsync(pos, src1->data, ggml_nbytes(src1), cudaMemcpyHostToDevice, main_stream)); | ||
src1_extra->copied = true; | ||
} | ||
} | ||
Comment on lines
+6105
to
+6113
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. With the current codebase I don't think there's much you can do to avoid this. The codebase currently covers constant data being copied to VRAM only before the eval directly from the model file. In all other cases the data is written to VRAM as the output of a tensor. You could of course just not offload |
||
|
||
const bool is_neox = mode & 2; | ||
const bool is_glm = mode & 4; | ||
|
||
// compute | ||
if (is_glm) { | ||
rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, n_ctx, main_stream); | ||
GGML_ASSERT(false); | ||
rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream); | ||
} else if (is_neox) { | ||
GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet"); | ||
rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream); | ||
rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); | ||
} else { | ||
rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream); | ||
rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); | ||
} | ||
|
||
(void) src1; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of a conditional statement I think it would be faster to either pass zeroed memory or to do the check via a template. In the latter case you could also simplify this code, since
`p == 0`
implies `sin_theta == 0`
and `cos_theta == 1`.