
Use RMSNorm #173

Closed
hoangmit opened this issue Mar 15, 2023 · 18 comments
Labels
bug (Something isn't working) · good first issue (Good for newcomers) · help wanted (Extra attention is needed) · high priority (Very important issue)

Comments

@hoangmit
Contributor

hoangmit commented Mar 15, 2023

The original paper and the reference implementation [1] use RMS norm. However, llama.cpp uses ggml_norm(), which looks like Layer norm?

The difference between the two may not be too obvious, because the mean is probably around 0. However, we should follow the original design.

[1] https://github.com/facebookresearch/llama/blob/main/llama/model.py
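
For reference, the two operations differ only in whether the mean is subtracted (learned scale weights omitted here):

LayerNorm(x)_i = (x_i - mean(x)) / sqrt(var(x) + eps)
RMSNorm(x)_i   = x_i / sqrt(mean(x^2) + eps)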

@ggerganov
Owner

Thanks for looking into this.
ggml_norm should be equivalent to RMS norm:

llama.cpp/ggml.c, lines 5325 to 5385 in 2d64715:

// ggml_compute_forward_norm

static void ggml_compute_forward_norm_f32(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
        struct ggml_tensor * dst) {
    GGML_ASSERT(ggml_are_same_shape(src0, dst));

    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
        return;
    }

    GGML_ASSERT(src0->nb[0] == sizeof(float));

    const int ith = params->ith;
    const int nth = params->nth;

    const int ne00 = src0->ne[0];
    const int ne01 = src0->ne[1];
    const int ne02 = src0->ne[2];
    const int ne03 = src0->ne[3];

    const size_t nb01 = src0->nb[1];
    const size_t nb02 = src0->nb[2];
    const size_t nb03 = src0->nb[3];

    const size_t nb1 = dst->nb[1];
    const size_t nb2 = dst->nb[2];
    const size_t nb3 = dst->nb[3];

    const ggml_float eps = 1e-5f; // TODO: make this a parameter

    // TODO: optimize
    for (int i03 = 0; i03 < ne03; i03++) {
        for (int i02 = 0; i02 < ne02; i02++) {
            for (int i01 = ith; i01 < ne01; i01 += nth) {
                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

                ggml_float mean = 0.0;
                for (int i00 = 0; i00 < ne00; i00++) {
                    mean += x[i00];
                }

                mean /= ne00;

                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);

                ggml_float sum2 = 0.0;
                for (int i00 = 0; i00 < ne00; i00++) {
                    ggml_float v = x[i00] - mean;
                    y[i00] = v;
                    sum2 += v*v;
                }

                const float scale = 1.0/sqrt(sum2/ne00 + eps);

                ggml_vec_scale_f32(ne00, y, scale);
            }
        }
    }
}

Did I miss something in the RMSNorm implementation?

@hoangmit
Contributor Author

hoangmit commented Mar 15, 2023

RMS norm does not need to compute the mean of the input elements. The implementation here does "v = x[i00] - mean" before accumulating sum2 += v*v, which looks similar to Layer norm (but not exactly Layer norm). Maybe I missed some details.
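
For comparison, here is a minimal self-contained sketch of an RMS norm over one row (not the actual patch; names are illustrative):

#include <math.h>

// Sketch: RMS norm over one row of n floats. Unlike the loop quoted
// above, there is no mean pass; the scale uses the raw second moment.
static void rms_norm_row(const float * x, float * y, int n, float eps) {
    float sum2 = 0.0f;
    for (int i = 0; i < n; i++) {
        sum2 += x[i]*x[i];
    }
    const float scale = 1.0f/sqrtf(sum2/n + eps);
    for (int i = 0; i < n; i++) {
        y[i] = x[i]*scale;
    }
}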


@ggerganov added the "bug" and "good first issue" labels Mar 15, 2023
@ggerganov
Owner

I think you are correct. We have to fix this.
Add ggml_rms_norm() and use it.
I should also see if I made the same mistake in the GPT-J example.

@blackhole89
Contributor

I have some limited evidence (see what I posted in #193) that this might have led to a regression in text quality, at least at 13B. Reopening for now, because I think it would be good to gather evidence.

@blackhole89 reopened this Mar 16, 2023
@blackhole89 added the "help wanted" label Mar 16, 2023
@blackhole89
Contributor

blackhole89 commented Mar 16, 2023

I tried to run 13B Q4_0 (i.e. not affected by my patch) with the RMS norm, and it also acted in a subpar way. In particular, "Allice" made a return; I've never seen it mangle the bot's name in that particular fashion with the original norm, but with RMS it's quite frequent.

sampling parameters: temp = 0.800000, top_k = 40, top_p = 0.950000, repeat_last_n = 64, repeat_penalty = 1.100000


== Running in interactive mode. ==
 - Press Ctrl+C to interject at any time.
 - Press Return to return control to LLaMa.
 - If you want to submit another line, end your input in '\'.
Transcript of a conversation between a User and an AI Assistant named Alice. Alice is helpful, kind, honest and never fails to answer the User's requests immediately and in great detail. Alice knows the answer to any question.

User: Hello, Alice.
Alice: Hello. How may I help you today?
User: Please tell me the largest city in Europe.
Alice: Sure. The largest city entirely in Europe is Moscow, the capital of Russia. Its population is estimated at 13.0 million residents within the city limits, and over 21.5 million in the metropolitan area.
User: Thanks. What is the second largest city?
Allice: London, the capital of England. The Greater London Authority estimates the city's population to be around 9,668,073 people.
User: What is the most important difference between London and Moscow?
Alice: London is located in Europe while Moscow is situated on both sides of Europe and Asia.
\end{code}

The problem is that we want the first two answers to be merged into one sentence, e.g.: "The largest city entirely in Europe is Moscow, its population is estimated at 13 million residents within the city limits and over 20 million in the metropolitan area."

Comment: what do you mean by "merge"? What should be done with the answer? Please provide a [mcve].

Comment: You'll need to define how you expect "merged" to look. If it is just a concatenated string, that's easy enough to achieve. If you want to remove unneeded characters and make things flow together nicely, then that gets harder.

Answer: Here is one solution where we first use `re` module to capture all the answers between \begin{code}
<User>
\end{code} and \begin{code}
</Answers>
\end{code} and save it in a variable named `answers`. Then we strip off leading `\section{` and trailing \begin{code}
</h2>
\end{code} from the string. Finally, we split by newline character to create an array of each answer and then we concatenate all answers into one single string.

\strong{Code:}

\begin{code}
import re
# Create a list

main: mem per token = 22357508 bytes
main:     load time = 153895.62 ms
main:   sample time =   229.70 ms
main:  predict time = 245612.77 ms / 480.65 ms per token
main:    total time = 444701.34 ms

Fixing the same seed, and resampling from the "What is the most important difference between London and Moscow?" point (i.e. after already importing the "Allice" mistake), without RMS I get the following far superior continuation:

(...)
User: What is the most important difference between London and Moscow?
Allice: Well, that would be a hard question to answer because there are several reasons why these two cities should be seen differently, although I know you will not agree with me on this one. But let's have another go at it. The first reason that comes to my mind is that London is a part of the UK, while Moscow is not. So, in terms of political and legal frameworks, the difference is great: London has its own parliament, its law system, etc., while Moscow doesn't.
User:

The seed is 1678928825, running with -t 4 on q4_0 quantised 13B weights. Not sure how reproducible the results are across machines, though.

@blackhole89
Contributor

blackhole89 commented Mar 16, 2023

A few experiments with q4_1 13B and different languages, same seed and initial prompt as before.

Old norm:

User: Thanks. Please translate the following sentence into Japanese: "I brought my computer to the repair shop, but they were closed."
Alice: No problem. It reads: "コンピュータを修理店に持ったんですが、お正しい通りから近くないと思う。"
User: Thanks. Please also give me a German translation of the same sentence.
Alice: Sure. The correct version in German is "Ich habe meinen Computer in die Reparaturwerkstatt gebracht, aber sie war geschlossen."
User: Great. Please also give me a Russian translation of the same sentence.
Alice: Sure. It reads: "Я принёс свой компьютер в ремонтное ателье, но они были закрыты."

(My rating: German and Russian correct; Japanese starts off almost passable (though it sounds like you just took the computer along on your trip to the repair shop rather than handing it in), but the second half is hallucinated, seemingly on the basis of reading "closed" as being about close proximity.)

New RMS code:

User: Thanks. Please translate the following sentence into Japanese: "I brought my computer to the repair shop, but they were closed."
Alice: No problem. It would be "われはコンピュータを修理局に行ったけど閲憐はつまりませんでした。"
User: Thanks. Please also give me a German translation of the same sentence.
Alice: No problem. It would be "Ich habe meine Computerin den Reparaturbetrieb gebracht, aber sie waren geschlossen."
User: Great. Please also give me a Russian translation of the same sentence.
Alice: Sure. It would be "Я принес свой компьютер в ремонтную мастерскую, но они были закрыты."

(German is like "I brought my computeress the repair shop", with incoherent gender/case structure; Russian is as good as before, arguably with a slightly better alternative for "repair shop"; the Japanese amounts to "I went the computer to the repair bureau, but [indecipherable]".)

The way it fails in the Japanese translation is actually quite fascinating. The first attempt seems like something that someone with a minimal understanding of English and an active fantasy could make up based on recognising the word "close". As for the second one, the whole sentence has some vaguely religious colour to it, and the second clause seems to be trying to talk about some sort of failure of pity, which makes me wonder if it wound up in "closing off your heart" concept space.

@hoangmit
Contributor Author

hoangmit commented Mar 16, 2023

Let's revert the change in "main.cpp" (i.e. the 3 instances of "ggml_rms_norm" back to "ggml_norm") if you think it gets obviously worse.

We need some quantifiable quality test to catch this type of regression. Maybe add perplexity?
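
For concreteness, a minimal sketch of the measurement (the helper name and interface are hypothetical):

#include <math.h>

// Hypothetical helper: perplexity over n tokens, where p[i] is the
// probability the model assigned to the i-th ground-truth token.
double perplexity(const double * p, int n) {
    double nll = 0.0; // accumulated negative log-likelihood
    for (int i = 0; i < n; i++) {
        nll -= log(p[i]);
    }
    return exp(nll / n); // ppl = exp(mean NLL); lower is better
}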

@blackhole89
Contributor

blackhole89 commented Mar 16, 2023

I think the worseness is usually more on the subtle end, and I haven't done enough of the reverse test (take a generation I'm unhappy with in the old norm mode and retry it in RMS), so it would be good if someone else could also contribute their thoughts. Also, I've only experimented with Q4_0/1 on the 13B model so far.

I agree it would be very nice if we had some quality metrics to evaluate various tweaks. Are there benchmarks out there that would be easy to obtain and run on our output without being too slow?

@hoangmit
Contributor Author

If we had a Python interface (text input -> next word) for this, it would be much easier to perform quality tests. Most of the NLP toolkits and datasets are readily available in the Python world.

@blackhole89
Contributor

Wrapping it like that would be pretty easy, but we'd have to decide on the sampling parameters. Do we know what Meta used in their own evaluation?

@hoangmit
Contributor Author

They don't specify the exact details in the paper. One of the figures shows "training loss". We can just use a basic perplexity measurement on its training data, e.g. how well the model recites portions of Wikipedia.

@hoangmit
Contributor Author

We also need to make sure the (non-quantized) FP16 model gives a probability distribution similar to the PyTorch reference. That is also easy to check.
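
A minimal sketch of one such check (hypothetical helper): compare the two next-token distributions with KL divergence.

#include <math.h>

// Hypothetical check: KL divergence D(p || q) between the reference
// next-token distribution p (PyTorch) and ours q (ggml FP16), both
// over the same vocabulary of size n.
double kl_divergence(const double * p, const double * q, int n) {
    double kl = 0.0;
    for (int i = 0; i < n; i++) {
        if (p[i] > 0.0 && q[i] > 0.0) {
            kl += p[i] * log(p[i] / q[i]);
        }
    }
    return kl; // ~0 when the distributions agree
}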

@ggerganov
Owner

Could it be due to a different norm_eps?

  • ggml : hardcoded to 1e-5
  • Python : defaults to 1e-5, but I believe it is a hyperparameter in the model, and the value there is 1e-6

@psymonryan

psymonryan commented Mar 16, 2023

User: Hello, Alice.
Allice: London, the capital of England. The Greater London Authority estimates the city's population to be around 9,668,073 people.
User: What is the most important difference between London and Moscow?
Alice: London is located in Europe while Moscow is situated on both sides of Europe and Asia.
\end{code}

The problem is that we want the first two answers to be merged into one sentence, e.g.: "The largest city entirely in Europe is Moscow, its population is estimated at 13 million residents within the city limits and over 20 million in the metropolitan area."

Comment: what do you mean by "merge"? What should be done with the answer? Please provide a [mcve].

Comment: You'll need to define how you expect "merged" to look. If it is just a concatenated string, that's easy enough to achieve. If you want to remove unneeded characters and make things flow together nicely, then that gets harder.

Answer: Here is one solution where we first use re module to capture all the answers between \begin{code}

\end{code} and \begin{code}

\end{code} and save it in a variable named answers. Then we strip off leading \section{ and trailing \begin{code}

\end{code} from the string. Finally, we split by newline character to create an array of each answer and then we concatenate all answers into one single string.

\strong{Code:}

\begin{code}
import re

FYI: I am observing this weird 'swap to code gibberish' with consistently higher probability if the question (or the answer) contains a single quote (').

I have observed this both before and after the RMS change.

@ggerganov
Owner

And by the way, since people are now looking more in-depth into the codebase (which is awesome!), the RoPE computation is another place to look for potential mistakes.
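
For anyone checking it, a minimal sketch of what RoPE is supposed to compute (my own summary, not the ggml code): each consecutive pair of elements in a head vector is rotated by a position-dependent angle.

#include <math.h>

// Sketch of RoPE on one head vector x of dimension d at position pos.
// Pair (x[i], x[i+1]) is rotated by angle pos * base^(-i/d), base = 10000.
void rope_rotate(float * x, int d, int pos, float base) {
    for (int i = 0; i < d; i += 2) {
        const float theta = pos * powf(base, -(float) i / d);
        const float c = cosf(theta);
        const float s = sinf(theta);
        const float x0 = x[i];
        const float x1 = x[i + 1];
        x[i]     = x0*c - x1*s;
        x[i + 1] = x0*s + x1*c;
    }
}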

@hoangmit
Contributor Author

hoangmit commented Mar 16, 2023

RoPE is tricky and easy to get wrong. We need a lot of unit tests for the operators. We have a reference implementation, so generating test data is not too hard.
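
One possible shape for such a test (hypothetical helper; the tolerances are placeholders): export input/output pairs from the PyTorch reference and assert that each ggml operator reproduces them.

#include <math.h>

// Hypothetical test helper: elementwise comparison of an operator's
// output against reference values exported from the PyTorch model.
static int all_close(const float * got, const float * ref, int n,
                     float atol, float rtol) {
    for (int i = 0; i < n; i++) {
        if (fabsf(got[i] - ref[i]) > atol + rtol*fabsf(ref[i])) {
            return 0; // mismatch at element i
        }
    }
    return 1;
}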

@blackhole89
Contributor

I'll try to look into rigging our output up with some benchmark once I get the time. It might be hard to reproduce the exact conditions of their evaluation (with Wikipedia articles we wouldn't have the same weighting/mixture, and for the more standardised benchmarks I'm struggling to identify the exact prompts and parameters they used), but it's probably reasonable to assume that no fundamentally deleterious change, like bad quantization or the use of a wrong norm, should result in a spurious improvement of perplexity on much of anything.

ggerganov added a commit that referenced this issue Mar 19, 2023
I think this is what is used in the Python code
@ggerganov
Owner

@blackhole89
I changed the eps to 1e-6; not sure if it makes any significant difference, but it seems to be what was used originally.

I will close this issue now. Please let us know if you make any progress with the benchmark and open an issue if needed.
Also, there is this one as well: #231

mudler pushed a commit to go-skynet/llama that referenced this issue Mar 19, 2023
I think this is what is used in the Python code