fix : lookup word in vocab before doing BPE merges #7193
Conversation
Force-pushed from 4ba2e5c to 63207d1
This change not only fixed the llama3 tokenization, but also improved performance by a factor of 4: ./tests/test-tokenizer-0.sh llama-bpe ./build/wikitext-2-raw/wiki.train.raw
We now tokenize
Which parameter in the tokenizer config determines this behaviour?
@ggerganov
Let's merge after the CI is green
Force-pushed from 0f48f9e to 0c9a0ae
llama.cpp (outdated diff)
// If the whole word is already a single token in the vocab, emit it directly
// and skip the byte-level BPE merges for this word.
if (ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
    llm_symbol sym;
    sym.text = word.c_str();
    sym.n    = word.size();
    sym.prev = final_prev_index;
    sym.next = -1;
    // link the new symbol to the previous one in the final symbol list
    if (final_prev_index != -1) {
        symbols_final[final_prev_index].next = symbols_final.size();
    }
    symbols_final.emplace_back(sym);
    final_prev_index = symbols_final.size() - 1;
    continue;
}
Let's apply @jaime-m-p's suggestion here, to reduce the code duplication in this loop:
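(The suggestion itself is not quoted in this thread. Purely as an illustration of the kind of de-duplication meant, a small helper that links a new symbol into symbols_final could absorb the repeated bookkeeping; the helper name and exact shape below are assumptions, not the actual patch. llm_symbol, symbols_final and final_prev_index are the names already used in the file under review.)

// Hypothetical helper, not the actual suggestion from the review:
// create a symbol for `text`/`n` and link it into symbols_final,
// keeping the prev/next indices consistent.
static void add_final_symbol(std::vector<llm_symbol> & symbols_final,
                             int & final_prev_index,
                             const char * text, size_t n) {
    llm_symbol sym;
    sym.text = text;
    sym.n    = n;
    sym.prev = final_prev_index;
    sym.next = -1;
    if (final_prev_index != -1) {
        symbols_final[final_prev_index].next = (int) symbols_final.size();
    }
    symbols_final.emplace_back(sym);
    final_prev_index = (int) symbols_final.size() - 1;
}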
For llama-3, I found there is an inconsistency between llama.cpp's tokenizer and Huggingface's tokenizers. Example:
llama.cpp:
Huggingface's tokenizers with tokenizer.json from llama-3:
After comparing the implementations, it seems that Huggingface's tokenizers will first try to look up a split word in the vocabulary, and push it to the result tokens if found; if not, it will try to merge the word at byte level instead. In llama.cpp, we always do the byte-level merge, hence the inconsistency.
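To make the intended order concrete, here is a minimal sketch of that "lookup first, then merge" flow (the function names, the token_to_id map type and byte_pair_encode are assumptions for illustration, not the actual llama.cpp or tokenizers code):

#include <string>
#include <unordered_map>
#include <vector>

// Placeholder for the usual byte-level BPE merge loop (assumption, not a real function).
std::vector<int> byte_pair_encode(const std::string & word,
                                  const std::unordered_map<std::string, int> & token_to_id);

// Sketch of the lookup-first behaviour described above.
std::vector<int> tokenize_word(const std::string & word,
                               const std::unordered_map<std::string, int> & token_to_id) {
    // 1. If the whole split word is already a vocab entry, return that single token id.
    auto it = token_to_id.find(word);
    if (it != token_to_id.end()) {
        return { it->second };
    }
    // 2. Otherwise, fall back to byte-level BPE merges on the word.
    return byte_pair_encode(word, token_to_id);
}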
This is a simple fix to the problem: just look the word up before doing the merge.
PS: I have checked with tiktoken and it seems they do the same thing at src/lib.rs:228 in CoreBPE::_encode_native
PPS: I searched the tokenizer.json of all BPE models (some are license-walled, so I checked their variants) and it seems that llama-3 is the only one doing this?