Allow specifying p scale factor for ggml rope and rope_back ops #1967

Closed · wanted to merge 4 commits

Conversation

@KerfuffleV2 (Collaborator) commented Jun 22, 2023

Based on the description in #1965

This adds ggml_rope_scaled, ggml_rope_scaled_inplace, and ggml_rope_back_scaled ops. The existing rope ops just pass 1.0 to the implementation, so the current behavior/API is unchanged.
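A rough sketch of the forwarding idea described above (this is an illustration, not the actual diff; the exact parameter list is assumed from the review hunks further down):

// sketch: the existing op keeps its signature and simply forwards with
// p_scale == 1.0f, so current callers see no behavior change
struct ggml_tensor * ggml_rope(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        int                   n_past,
        int                   n_dims,
        int                   mode) {
    return ggml_rope_scaled(ctx, a, n_past, n_dims, mode, /*p_scale =*/ 1.0f);
}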

Adds LLAMA_ROPE_SCALE to the Makefile (note: not in CMake yet); if not specified, it defaults to 1.0. You can run, for example, `make LLAMA_ROPE_SCALE=0.5`.

I'd guess we probably don't want to use the LLAMA_ROPE_SCALE approach if this actually gets merged; it's mainly there to make testing easier.

edit: Theoretically this should work for CUDA and Metal now (untested). Note that the GPU ops only apply when every possible layer is loaded into VRAM (for 7B that would be -ngl 35).

@KerfuffleV2 KerfuffleV2 marked this pull request as draft June 22, 2023 14:00
Bail out if p_scale != 1.0 in rope operation for the time being
@KerfuffleV2 (Collaborator, Author)

Man, I'm so confused. CUDA has a rope operation and wasn't set up to handle my changes, but apparently it doesn't actually ever get called? Even just asserting false at the beginning doesn't result in failure.

@KerfuffleV2 KerfuffleV2 marked this pull request as ready for review June 22, 2023 14:19
@SlyEcho (Collaborator) commented Jun 22, 2023

CUDA rope gets called when the tensors are loaded onto VRAM (-ngl 35), I think.

@KerfuffleV2 (Collaborator, Author)

CUDA rope gets called when the tensors are loaded onto VRAM

Oh, thanks. That's the kind of thing only people with a decent GPU would notice, I guess!

@SlyEcho (Collaborator) commented Jun 22, 2023

I have an old card but it can run 7B Q4_0 with its 8GB VRAM.

@KerfuffleV2 (Collaborator, Author)

I have an old card but it can run 7B Q4_0 with its 8GB VRAM.

Enjoying your powerful 8GB VRAM GPU, sitting back in your comfy chair, probably eating avocado toast every day. Life must be pretty sweet!

But yeah, I only have 6GB and I'm pretty sure -c 4096 + offloading 35 layers just isn't happening even at Q4_0, probably not even with Q2_K.

@SlyEcho (Collaborator) commented Jun 22, 2023

probably eating avocado toast every day

Yeah, my wife is Mexican, I get a lot of nice avocado everything, all the time.

Maybe you can try the OpenLLaMA 3B? The model itself may not be amazing but getting a comparative baseline may be possible. I have some model files on HF available.

@KerfuffleV2 (Collaborator, Author)

Yeah, my wife is Mexican, I get a lot of nice avocado everything, all the time.

Sounds nice! I love avocados. The joke is that avocados are so expensive you have to mortgage your house to eat avocado toast. :)

Maybe you can try the OpenLLaMA 3B?

It wouldn't really help much currently, because I didn't (and don't have the knowledge to) write an actual implementation for Metal/CUDA. All I did was add an assert that errors out if a scale factor other than 1.0 reaches the Metal or CUDA rope implementations.

@SlyEcho (Collaborator) commented Jun 22, 2023

I don't know, it seems pretty easy to edit ggml-cuda.cu; the theta calculation is only in one place.

@KerfuffleV2 (Collaborator, Author)

I don't know, it seems pretty easy to edit ggml-cuda.cu; the theta calculation is only in one place.

The problem is I made the change just based on the example in the original discussion, not based on actually understanding the math. So I don't know how to convert it from the iterative CPU version to the parallel GPU kernel form and be sure that it's correct.

If you want to show me what it should be, I'd be happy to update the PR with the change.

@SlyEcho (Collaborator) commented Jun 22, 2023

It's in rope_f32():

const float theta = p*powf(theta_scale, col/2);

I just modified it there like this:

const float theta = 0.5*p*powf(theta_scale, col/2);

And I got pretty much the same result for perplexity as is posted (5.9839) in that thread.

Actually, now that I'm looking at it, p is passed down, so it could already be modified before this GPU kernel function.
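For illustration, a sketch of that host-side variant (the p_scale parameter is hypothetical plumbing; the position computation follows the ggml_cuda_op_rope() snippet quoted later in this thread):

// sketch: scale the position on the host before it is handed to the kernel,
// so the rope_f32() kernel itself stays unchanged; p_scale == 1.0f keeps
// the current behavior
const float p0 = ((mode & 1) == 0 ? n_past + i02 : i02);
const float p  = p_scale * p0;
// rope_f32_cuda(..., p, theta_scale, ...);   // existing launch, unchanged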

@JohannesGaessler (Collaborator)

Man, I'm so confused. CUDA has a rope operation and wasn't set up to handle my changes, but apparently it doesn't actually ever get called? Even just asserting false at the beginning doesn't result in failure.

As @SlyEcho said, the RoPE tensor is only evaluated on the GPU if enough layers are being offloaded, specifically if the K component of the KV cache gets offloaded (num layers + 3). This is because, when I tested it, the overhead from cudaDeviceSynchronize was not worth the speedup. I have since implemented better synchronization logic, though, so I may need to revisit that.

@KerfuffleV2 (Collaborator, Author) commented Jun 22, 2023

I was able to test this with CUDA (although only with a context length of 128 and 35 layers offloaded), and it produces reasonable results. I am a lot less sure about the Metal change.

(Also, thanks SlyEcho!)

@SlyEcho (Collaborator) commented Jun 22, 2023

OpenLLaMA 3B doesn't seem to like this method... Or maybe just my power function

@KerfuffleV2 (Collaborator, Author)

OpenLLaMA 3B doesn't seem to like this method... Or maybe just my power function

I didn't do it quite the way you said; I tried the approach of passing p down, but in my ignorance I might not have gotten it right.

https://github.com/ggerganov/llama.cpp/pull/1967/files#diff-66b17223e8ba54054fb2600ecbd31107f8b917bac36c7f3789811b0f0e9802a1L1915-R1921

That's just scaling p before it gets passed to rope_f32_cuda.

@KerfuffleV2 (Collaborator, Author)

@JohannesGaessler Since you already have a branch with this PR's changes (except maybe the GPU stuff which may well be broken) do you want to just take over adding that functionality?

@SlyEcho (Collaborator) commented Jun 22, 2023

That's just scaling p before it gets passed to rope_f32_cuda.

I think it's better this way.

@JohannesGaessler (Collaborator)

In the first place, I think we'll need to wait for the opinion of @ggerganov regarding the ggml changes. But I think it's fine to just have two separate PRs since the features should work independently from one another.

@KerfuffleV2 (Collaborator, Author)

But I think it's fine to just have two separate PRs since the features should work independently from one another.

Just checking. I'd generally agree, but in this case I'm playing with stuff I don't fully understand (and in the case of Metal can't even test at all) so I thought I'd offer to let someone competent take over.

I think it's better this way.

Didn't you say it doesn't work, though? Or am I misunderstanding what "OpenLLaMA 3B doesn't seem to like this method..." meant?

assert(ggml_nelements(src1) == 4);
const int n_past = (int)((float *) src1->data)[0];
const int n_dims = (int)((float *) src1->data)[1];
const int mode = (int)((float *) src1->data)[2];
Contributor (review comment on the hunk above):

These are UB.

@KerfuffleV2 (Collaborator, Author) replied:

These are UB.

Can you explain in more detail? Are you saying it's illegal to cast float to int in CUDA?

Also, just to be clear, those values were originally ints that were cast to float, so we know they started out as integer values, and none of them should be outside the range where a 32-bit float can represent integers exactly (roughly ±16 million).

@SlyEcho (Collaborator) commented Jun 22, 2023

Didn't you say it doesn't work, though? Or am I misunderstanding what "OpenLLaMA 3B doesn't seem to like this method..." meant?

The model just doesn't seem to support the context scaling well, same for 7B.

@KerfuffleV2 (Collaborator, Author)

The model just doesn't seem to support the context scaling well, same for 7B.

It seems to work somewhat better on larger models. The small models are probably just barely clinging to sanity in the first place; stuff like this may be enough to push them over the edge.

Another thing that might help is using a good sampler and parameters that help keep the model on track, like Mirostat.

@Midaychi

As far as I understood it, this technique was designed to take advantage of the overfitting trained into the original LLaMA models. It's unclear if this applies to OpenLLaMA as well. Theoretically it should, but in practice it's possible it might not.

@ggerganov (Owner) left a review:

These are suggestions to take into account if we decide to merge the rope scaling option.

But, for now I think we should wait it out and see how applicable this thing is, because we don't want to add and support an option that will potentially never be used.

Comment on lines +6699 to +6702
((float *) b->data)[0] = (float)n_past;
((float *) b->data)[1] = (float)n_dims;
((float *) b->data)[2] = (float)mode;
((float *) b->data)[3] = p_scale;
@ggerganov (Owner):

Use memcpy to store the params so we can all sleep well knowing this is not UB :)
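For concreteness, a minimal sketch of that memcpy approach, assuming the four-float layout from the hunks in this review (memcpy comes from <string.h>):

// write side (where the params tensor b is filled):
float params[4] = { (float) n_past, (float) n_dims, (float) mode, p_scale };
memcpy(b->data, params, sizeof(params));

// read side (where src1 is consumed), mirroring the write:
float buf[4];
memcpy(buf, src1->data, sizeof(buf));
const int   n_past  = (int) buf[0];
const int   n_dims  = (int) buf[1];
const int   mode    = (int) buf[2];
const float p_scale =       buf[3];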

int n_past,
int n_dims,
int mode,
float p_scale);
@ggerganov (Owner), on the hunk above:

No need to extend the API - add p_scale to original ggml_rope_xxx() and add a comment to use p_scale == 1.0f for the regular computation. Add GGML_ASSERT(p_scale == 1.0f) in the backward call.
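A sketch of what that consolidation could look like (the declaration below is an assumption for illustration; only the p_scale parameter and the assert come from the suggestion above):

// ggml.h: p_scale added to the existing op; use p_scale == 1.0f for the
// regular (unscaled) computation
GGML_API struct ggml_tensor * ggml_rope(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        int                   n_past,
        int                   n_dims,
        int                   mode,
        float                 p_scale);

// in the backward call, until a scaled backward pass exists:
GGML_ASSERT(p_scale == 1.0f);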

@KerfuffleV2 (Collaborator, Author) replied:

No need to extend the API - add p_scale to original ggml_rope_xxx()

Won't this break every single thing that currently uses the llama.cpp version of GGML?

What do you think about using a define to enable the p_scale argument for rope and having it be off by default? That way existing stuff can opt in.

It might also be worth thinking about adding something like GGML_API_VERSION which could be bumped when incompatible changes occur, so stuff building against GGML could handle API changes more gracefully.

@KerfuffleV2 (Collaborator, Author) commented Jun 24, 2023

Unfortunately my system (and internet) recently got fried by a lightning strike. I don't have the ability to make changes to this PR at the moment.

If anyone wants to take over and make whatever changes are necessary (if we come to that point) then I'm fine with it. It may be several days before I have a functional system again.

for now I think we should wait it out and see how applicable this thing is

That was my thinking too, which is why I didn't do stuff like add command-line options. I was trying to keep my changes fairly minimal and limited to the basic functionality.

@ikawrakow (Contributor)

The issue with linearly scaling the token position is that it reduces accuracy for smaller context windows quite a bit. E.g., if I scale by 0.5 and use a context of 512, I get a perplexity of 7.0012 for LLaMA 7B with Q4_0. This is to be compared to 5.9066 for the fp16 model, or 6.1565 for Q4_0 with no scaling.

I think it would be better to have a function that is nearly linear for positions near the beginning of the context window and then gracefully slows down as the position increases. Here is an example that behaves pretty well for context sizes of up to 3584: in ggml_cuda_op_rope(), change the position calculation to

const float p0 = ((mode & 1) == 0 ? n_past + i02 : i02);
const float p = p0/(1 + p0/6144.f);

Using this and Q6_K quantization (to avoid wondering to what extent the difference to fp16 is due to the modification of the position vs the non-negligible quantization error of Q4_0), I get

Context | Modified RoPE (non-linear) | Modified RoPE (linear) | Original RoPE, fp16
512     | 5.9195                     | 6.8045                 | 5.9066
1024    | 5.4609                     | 6.1476                 | 5.4305
2048    | 5.3974                     | 5.9208                 | 5.2810
3072    | 5.3883                     | 5.7757                 | -
3584    | 5.8047                     | 5.7387                 | -

So, basically, we can extend the context window to 3072 without much loss in accuracy. The p = p0/(1 + p0/6144) function gives p = 2048 at 3072 and p = 2264 at 3584, so it is interesting to see that one can "step out" of the 2048 training context window by some margin without perplexity going to infinity. Another function that works quite well up to a context of 3072 is

p = 2048.f * atanf(p0 / 2048);

If one wants to get to a context window of 4096, these are examples of mappings that work better than p = 0.5f * p0:

p = 2048.f * 4.f / 3.14159f * atanf(p0 / 4096);  // ppl(4096) = 5.7304
p = 5050.f * logf(1 + p0/8192.f);                 // ppl(4096) = 5.7145

They are slightly better than linear scaling at 4096, and quite a bit better for smaller context windows.

So, overall, I think it should be possible to pick a position transformation, depending on the requested context size, that works better than linear scaling. Ideally, this should be done automatically instead of asking the user to recompile llama.cpp if they want to use a context larger than 2048 (and then forgetting about it and wondering why their results seem worse with a context < 2048).

I did some experimentation for contexts beyond 4096, but there things get flaky (hitting asserts because scratch buffers aren't big enough, CUDA randomly giving NaNs while the exact same position mapping works fine on the CPU, etc.), so I think we also need to look into making things more stable in that regime.
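Pulling the candidate mappings above into one place, here is a sketch of what an automatic, context-size-dependent transformation could look like (the branching thresholds are an assumption for illustration, not something settled in this thread):

#include <math.h>

// sketch: map the raw position p0 to the value fed into rope, choosing a
// transformation based on the requested context size n_ctx
static float rope_map_position(float p0, int n_ctx) {
    if (n_ctx <= 2048) {
        return p0;                             // original RoPE, no change
    }
    if (n_ctx <= 3072) {
        return p0 / (1.0f + p0 / 6144.0f);     // near-linear early, saturating later
    }
    // mapping that held up reasonably well up to ~4096 in the runs above
    return 2048.0f * 4.0f / 3.14159f * atanf(p0 / 4096.0f);
}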

@SlyEcho (Collaborator) commented Jun 25, 2023

beyond 4096

Also runs into memory limits...

@kaiokendev

The issue with linearly scaling the token position is that it reduces accuracy for smaller context windows quite a bit. E.g., if I scale by 0.5 and use a context of 512, I get a perplexity of 7.0012 for LLaMA 7B with Q4_0. This is to be compared to 5.9066 for the fp16 model, or 6.1565 for Q4_0 with no scaling.

Hello,
Please understand that the scaling is intended to be used with fine-tuning. I have mentioned several times that it is not a "free" context extension for pre-trained models. The intention is to fine-tune the model with the scale and treat it as a hyperparameter. In any case, I have performed a perplexity evaluation against base LLaMA 13B, SuperHOT 13B (4096 sequence length, with 0.25 scale during training), and SuperHOT 13B (4096 sequence length, with no scale during training). The results should clarify that, after fine-tuning, the scaled model is able to beat even the unscaled models. Please do keep that in mind: you will not see the actual benefits unless you fine-tune the model.

Perplexity (ppl) results: https://files.catbox.moe/7q75zx.json

@ikawrakow (Contributor)

@kaiokendev Sorry if my comment came across as being critical. Coming up with the idea of scaling the positions is really great (it would never have occurred to me to try that), and I totally get that one can improve via fine-tuning. But the point of my comment was to try to go beyond the original idea. From the experiments I have done, it looks like one can get away without fine-tuning for contexts up to 4096 using simple non-linear position mappings. Or, if one decided to fine-tune anyway, my uneducated guess would be that one would end up with a better final result having started with a better initial guess. My initial post only showed perplexities for LLaMA-7B. Looking at 13B, I get a perplexity of 4.8 at a context size of 3500, which seems much better than the graph and the data in the linked json, no? I cannot easily go beyond 3500 because the llama.cpp GPU code stops functioning (gives NaNs), and it is too slow to do such investigations on the CPU.

@kaiokendev commented Jun 26, 2023

@ikawrakow No, I am not taking the comment critically, lol. I am really following everyone who is taking it to the limits, and it's really exciting. I am just reminding that the attention heads are not calibrated to the interpolated positions, so it is doubtful the model is fully leveraging those interpolated positions. Maybe it is the case up to a point, but the other side is that it is secretly ditching certain concepts from the interpolated portions. I'm not meaning to stop the work; in fact, if you can show it is extrapolating even without fine-tuning and is really using the 3072, then it has a lot of value. Many people will be blown away by that discovery, since it is thought the untrained model cannot extrapolate at all. The value in the chart is higher because I use 4-bit precision models.

@JohannesGaessler (Collaborator)

I cannot easily go beyond 3500 because the llama.cpp GPU code stops functioning (gives NaNs), and it is too slow to do such investigations on the CPU.

I'll look into it in the next few days.

@FNsi (Contributor) commented Jun 27, 2023

Set to 0.125 with a 16k-context LoRA, and it did give a good result...

Sorry, that won't happen with OpenBLAS.

I rebuilt with BLAS and it ran like shit.

@KerfuffleV2 (Collaborator, Author)

Closing in favor of #2019 which is probably a better approach.

@KerfuffleV2 KerfuffleV2 deleted the feat-rope_scaled branch September 6, 2023 08:51