Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow specifying p scale factor for ggml rope and rope_back ops #1967

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ ifndef LLAMA_NO_K_QUANTS
OBJS += k_quants.o
endif

ifdef LLAMA_ROPE_SCALE
CXXFLAGS += -DLLAMA_ROPE_SCALE=$(LLAMA_ROPE_SCALE)
endif

ifndef LLAMA_NO_ACCELERATE
# Mac M1 - include Accelerate framework.
# `-framework Accelerate` works on Mac Intel as well, with negliable performance boost (as of the predict time).
Expand Down
11 changes: 7 additions & 4 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1906,13 +1906,16 @@ inline void ggml_cuda_op_rope(
const int64_t ne00 = src0->ne[0];
const int64_t i01_diff = i01_high - i01_low;

const int n_past = ((int32_t *) src1->data)[0];
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
assert(src1->type == GGML_TYPE_F32);
assert(ggml_nelements(src1) == 4);
const int n_past = (int)((float *) src1->data)[0];
const int n_dims = (int)((float *) src1->data)[1];
const int mode = (int)((float *) src1->data)[2];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are UB.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are UB.

Can you explain in more detail? Are you saying it's illegal to cast float to int in CUDA?

Also, just to be clear, those values were originally ints that were cast to float so we know they started out as integer values. Also, none of them should be out of the range of a 32 bit float's ability to represent integers (should be ~16mil for signed).

const float p_scale = ((float *) src1->data)[3];
GGML_ASSERT(mode == 0);

const float theta_scale = powf(10000.0, -2.0f/n_dims);
const float p = ((mode & 1) == 0 ? n_past + i02 : i02);
const float p = p_scale * ((mode & 1) == 0 ? n_past + i02 : i02);

// compute
rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
Expand Down
49 changes: 26 additions & 23 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -861,33 +861,36 @@ void ggml_metal_graph_compute(
encoder = [command_buffer computeCommandEncoder];
}

const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];

const int n_past = ((int32_t *)(src1->data))[0];
assert(src1->type == GGML_TYPE_F32);
assert(ggml_nelements(src1) == 4);
const int n_past = (int)((float *) src1->data)[0];
const int n_dims = (int)((float *) src1->data)[1];
const int mode = (int)((float *) src1->data)[2];
const float p_scale = ((float *) src1->data)[3];

[encoder setComputePipelineState:ctx->pipeline_rope];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&n_past length:sizeof( int) atIndex:18];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
[encoder setBytes:&mode length:sizeof( int) atIndex:20];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&n_past length:sizeof( int) atIndex:18];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
[encoder setBytes:&mode length:sizeof( int) atIndex:20];
[encoder setBytes:&p_scale length:sizeof( float) atIndex:21];

[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
Expand Down
3 changes: 2 additions & 1 deletion ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,7 @@ kernel void kernel_rope(
constant int & n_past,
constant int & n_dims,
constant int & mode,
constant float & p_scale,
uint3 tpig[[thread_position_in_grid]]) {
const int64_t i3 = tpig[2];
const int64_t i2 = tpig[1];
Expand All @@ -625,7 +626,7 @@ kernel void kernel_rope(

const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);

float theta = (float)p;
float theta = p_scale * (float)p;

if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
Expand Down
Loading