Tokenizer WPM fixes for bert-bge and jina-v2-en #7500

Merged: 5 commits, May 28, 2024
llama.cpp: 222 changes (62 additions, 160 deletions)
@@ -2086,7 +2086,7 @@ struct llama_vocab {
std::unordered_map<token, id> token_to_id;
std::vector<token_data> id_to_token;

std::unordered_map<token, id> special_tokens_cache;
std::vector<id> special_tokens_cache;

std::map<std::pair<std::string, std::string>, int> bpe_ranks;

@@ -4724,97 +4724,19 @@ static void llm_load_vocab(

// build special tokens cache
{
// TODO: It is unclear (to me) at this point, whether special tokens are guaranteed to be of a deterministic type,
// and will always be correctly labeled in 'added_tokens.json' etc.
// The assumption is, since special tokens aren't meant to be exposed to end user, they are designed
// to be unmatchable by the tokenizer, therefore tokens from the vocab, which are unmatchable by the tokenizer
// are special tokens.
// From testing, this appears to correlate 1:1 with special tokens.
//

// Counting special tokens and verifying in only one direction
// is sufficient to detect difference in those two sets.
//
uint32_t special_tokens_count_by_type = 0;
uint32_t special_tokens_count_from_verification = 0;

bool special_tokens_definition_mismatch = false;

for (const auto & t : vocab.token_to_id) {
const auto & token = t.first;
const auto & id = t.second;

// Count all non-normal tokens in the vocab while iterating
for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) {
special_tokens_count_by_type++;
vocab.special_tokens_cache.push_back(id);
}
}

// Skip single character tokens
if (token.length() > 1) {
bool is_tokenizable = false;

// Split token string representation in two, in all possible ways
// and check if both halves can be matched to a valid token
for (unsigned i = 1; i < token.length();) {
const auto left = token.substr(0, i);
const auto right = token.substr(i);

// check if we didn't partition in the middle of a UTF-8 sequence
auto utf = utf8_len(left.at(left.length() - 1));

if (utf == 1) {
if (vocab.token_to_id.find(left) != vocab.token_to_id.end() &&
vocab.token_to_id.find(right) != vocab.token_to_id.end() ) {
is_tokenizable = true;
break;
}
i++;
} else {
// skip over the rest of multibyte utf sequence
i += utf - 1;
}
}

if (!is_tokenizable) {
// Some tokens are multibyte, but they are utf sequences with equivalent text length of 1
// it's faster to re-filter them here, since there are way less candidates now

// Calculate a total "utf" length of a token string representation
size_t utf8_str_len = 0;
for (unsigned i = 0; i < token.length();) {
utf8_str_len++;
i += utf8_len(token.at(i));
}

// And skip the ones which are one character
if (utf8_str_len > 1) {
// At this point what we have left are special tokens only
vocab.special_tokens_cache[token] = id;

// Count manually found special tokens
special_tokens_count_from_verification++;

// If this manually found special token is not marked as such, flag a mismatch
if (vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL) {
special_tokens_definition_mismatch = true;
}
}
}
std::sort( vocab.special_tokens_cache.begin(), vocab.special_tokens_cache.end(),
[&] (const llama_vocab::id a, const llama_vocab::id b) {
return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
}
}
);

if (special_tokens_definition_mismatch || special_tokens_count_from_verification != special_tokens_count_by_type) {
LLAMA_LOG_WARN("%s: mismatch in special tokens definition ( %u/%zu vs %u/%zu ).\n",
__func__,
special_tokens_count_from_verification, vocab.id_to_token.size(),
special_tokens_count_by_type, vocab.id_to_token.size()
);
} else {
LLAMA_LOG_INFO("%s: special tokens definition check successful ( %u/%zu ).\n",
__func__,
special_tokens_count_from_verification, vocab.id_to_token.size()
);
}
LLAMA_LOG_INFO("%s: special tokens cache size = %u.\n", __func__, (uint32_t)vocab.special_tokens_cache.size());
}
}
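The net effect of the new block is simple: every token whose type is not NORMAL goes into the cache, and the cache is sorted so that longer special-token strings are tried before shorter ones. A minimal self-contained sketch of that logic (the token_data struct and type enum below are simplified stand-ins for the real llama_vocab definitions):

```cpp
#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

// Simplified stand-ins for the llama_vocab token bookkeeping.
enum token_type { TOKEN_TYPE_NORMAL, TOKEN_TYPE_CONTROL, TOKEN_TYPE_USER_DEFINED };

struct token_data {
    std::string text;
    token_type  type;
};

// Collect the ids of all non-normal tokens and sort them by decreasing
// text length, so the longest special tokens are matched first.
static std::vector<int32_t> build_special_tokens_cache(const std::vector<token_data> & id_to_token) {
    std::vector<int32_t> cache;
    for (int32_t id = 0; id < (int32_t) id_to_token.size(); ++id) {
        if (id_to_token[id].type != TOKEN_TYPE_NORMAL) {
            cache.push_back(id);
        }
    }
    std::sort(cache.begin(), cache.end(), [&](int32_t a, int32_t b) {
        return id_to_token[a].text.size() > id_to_token[b].text.size();
    });
    return cache;
}
```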

@@ -12738,7 +12660,7 @@ struct llm_tokenizer_wpm {
llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}

void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
auto * token_map = &vocab.token_to_id;
const auto & token_map = vocab.token_to_id;

// normalize and split by whitespace
std::vector<std::string> words = preprocess(text);
@@ -12753,108 +12675,89 @@ struct llm_tokenizer_wpm {
}

// prepend phantom space
std::string word1 = "\xe2\x96\x81" + word;
int n = word1.size();
const std::string word1 = "\xe2\x96\x81" + word;
const int n = word1.size();

// we're at the start of a new word
int i = 0;
bool match_any = false;
const size_t current_tokens = output.size();

// we're at the start of a new word
// move through character position in word
while (i < n) {
for (int i = 0; i < n; ++i) {
// loop through possible match length
bool match = false;
for (int j = n; j > i; j--) {
auto it = token_map->find(word1.substr(i, j - i));
if (it != token_map->end()) {
auto it = token_map.find(word1.substr(i, j - i));
if (it != token_map.end()) {
output.push_back(it->second);
match = true;
match_any = true;
i = j;
i = j - 1;
break;
}
}

// must be an unknown character
if (!match) {
i++;
if (!match) { // discard all
output.resize(current_tokens);
break; // and discard next tokens
}
}

// we didn't find any matches for this word
if (!match_any) {
if (current_tokens == output.size()) {
output.push_back(vocab.special_unk_id);
}
}
}
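The rewritten tokenize loop is a greedy longest-match WordPiece: at each position the longest matching substring wins, and a single failed position discards everything emitted for that word and falls back to one UNK token. A standalone sketch of the per-word logic, with a plain std::unordered_map and an unk_id parameter standing in for the llama_vocab members:

```cpp
#include <string>
#include <unordered_map>
#include <vector>

// Greedy longest-match WordPiece over one whitespace-split word.
// 'vocab' maps token text to id; 'unk_id' is emitted when the word
// cannot be fully covered by vocabulary entries.
void wpm_tokenize_word(const std::string & word,
                       const std::unordered_map<std::string, int> & vocab,
                       int unk_id,
                       std::vector<int> & output) {
    const std::string word1 = "\xe2\x96\x81" + word; // prepend phantom space
    const int n = (int) word1.size();

    const size_t current_tokens = output.size();

    for (int i = 0; i < n; ++i) {
        bool match = false;
        // try the longest candidate first
        for (int j = n; j > i; --j) {
            auto it = vocab.find(word1.substr(i, j - i));
            if (it != vocab.end()) {
                output.push_back(it->second);
                match = true;
                i = j - 1; // continue right after the matched piece
                break;
            }
        }
        if (!match) {
            // unknown character: the whole word collapses to UNK
            output.resize(current_tokens);
            break;
        }
    }

    if (output.size() == current_tokens) {
        output.push_back(unk_id);
    }
}
```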

std::vector<std::string> preprocess(const std::string & text) {
std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));

// strip accents, strip control, uniformize whitespace,
// to lowercase, pad chinese characters, pad punctuation
std::string new_str = "";
for (uint32_t code : cpts_nfd) {
const codepoint_flags flags = unicode_cpt_flags(code);
if (flags.is_accent_mark || flags.is_control) {
const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
std::vector<std::string> words(1, "");

for (const char32_t cpt : cpts_nfd) {
const auto flags = unicode_cpt_flags(cpt);

if (flags.is_whitespace) {
if (words.back().size()) { // finish previous word if any
words.emplace_back();
}
continue;
}
code = unicode_tolower(code);
if (flags.is_separator || flags.is_whitespace) { //####FIXME: is_separator ?
code = ' ';
}
std::string s = unicode_cpt_to_utf8(code);
if (flags.is_punctuation || is_ascii_punct(code) || is_chinese_char(code)) {
new_str += " ";
new_str += s;
new_str += " ";
} else {
new_str += s;

assert (!flags.is_separator);
if (cpt == 0 || cpt == 0xFFFD || flags.is_control) {
continue;
}
}

// split by whitespace
uint64_t l = 0;
uint64_t r = 0;
std::vector<std::string> words;
while (r < new_str.size()) {
// if is whitespace
if (isspace(new_str[r], std::locale::classic())) {
if (r > l) words.push_back(new_str.substr(l, (r - l)));
l = r + 1;
r = l;
const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt));
if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) {
if (words.back().size()) { // finish previous word if any
words.emplace_back();
}
words.back() = s; // single char word
words.emplace_back(); // start a new word
} else {
r += 1;
words.back() += s; // append char to word
}
}
if (r > l) {
words.push_back(new_str.substr(l, (r - l)));
}
return words;
}

bool is_ascii_punct(uint32_t code) {
if (code > 0xFF) {
return false;
if (!words.back().size()) {
words.pop_back();
}
auto c = char(static_cast<unsigned char>(code));
return ispunct(c, std::locale::classic());

return words;
}
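The new preprocess() builds the word list in a single pass instead of constructing an intermediate string and re-splitting it. The sketch below mirrors that structure but is restricted to ASCII, using <cctype> in place of the unicode_* helpers, so it is only an approximation of the real codepoint-level behaviour:

```cpp
#include <cctype>
#include <string>
#include <vector>

// ASCII-only sketch of the word splitting done in preprocess():
// whitespace separates words, punctuation becomes its own single-char word,
// everything else is lowercased and appended to the current word.
std::vector<std::string> wpm_preprocess_ascii(const std::string & text) {
    std::vector<std::string> words(1, "");

    for (unsigned char c : text) {
        if (std::isspace(c)) {
            if (!words.back().empty()) {   // finish previous word, if any
                words.emplace_back();
            }
            continue;
        }
        if (std::iscntrl(c)) {
            continue;                      // drop control characters
        }
        if (std::ispunct(c)) {
            if (!words.back().empty()) {
                words.emplace_back();
            }
            words.back() = std::string(1, (char) std::tolower(c)); // single-char word
            words.emplace_back();          // start a new word
        } else {
            words.back() += (char) std::tolower(c);
        }
    }

    if (words.back().empty()) {
        words.pop_back();                  // drop trailing empty word
    }
    return words;
}
```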

bool is_chinese_char(uint32_t cpt) {
if ((cpt >= 0x4E00 && cpt <= 0x9FFF) ||
(cpt >= 0x3400 && cpt <= 0x4DBF) ||
static bool is_chinese_char(uint32_t cpt) {
return
(cpt >= 0x04E00 && cpt <= 0x09FFF) ||
(cpt >= 0x03400 && cpt <= 0x04DBF) ||
(cpt >= 0x20000 && cpt <= 0x2A6DF) ||
(cpt >= 0x2A700 && cpt <= 0x2B73F) ||
(cpt >= 0x2B740 && cpt <= 0x2B81F) ||
(cpt >= 0x2B920 && cpt <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920
(cpt >= 0xF900 && cpt <= 0xFAFF) ||
(cpt >= 0x2F800 && cpt <= 0x2FA1F) ||
(cpt >= 0x3000 && cpt <= 0x303F) ||
(cpt >= 0xFF00 && cpt <= 0xFFEF)) {
return true; // NOLINT
}
return false;
(cpt >= 0x0F900 && cpt <= 0x0FAFF) ||
(cpt >= 0x2F800 && cpt <= 0x2FA1F);
//(cpt >= 0x3000 && cpt <= 0x303F) ||
//(cpt >= 0xFF00 && cpt <= 0xFFEF);
}

const llama_vocab & vocab;
@@ -12898,9 +12801,8 @@ struct fragment_buffer_variant {

static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
// for each special token
for (const auto & st: vocab.special_tokens_cache) {
const auto & special_token = st.first;
const auto & special_id = st.second;
for (const llama_vocab::id special_id : vocab.special_tokens_cache) {
const auto & special_token = vocab.id_to_token[special_id].text;

// for each text fragment
std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
@@ -12909,7 +12811,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<

// if a fragment is text ( not yet processed )
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
auto * raw_text = &(fragment.raw_text);
auto & raw_text = fragment.raw_text;

auto raw_text_base_offset = fragment.offset;
auto raw_text_base_length = fragment.length;
@@ -12919,7 +12821,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
// find the first occurrence of a given special token in this fragment
// passing offset argument only limit the "search area" but match coordinates
// are still relative to the source full raw_text
auto match = raw_text->find(special_token, raw_text_base_offset);
auto match = raw_text.find(special_token, raw_text_base_offset);

// no occurrences found, stop processing this fragment for a given special token
if (match == std::string::npos) break;
@@ -12938,7 +12840,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
// left
const int64_t left_reminder_offset = raw_text_base_offset + 0;
const int64_t left_reminder_length = match - raw_text_base_offset;
buffer.emplace_after(it, (*raw_text), left_reminder_offset, left_reminder_length);
buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);

#ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
@@ -12954,7 +12856,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
const int64_t right_reminder_offset = match + special_token.length();
const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
buffer.emplace_after(it, (*raw_text), right_reminder_offset, right_reminder_length);
buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);

#ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
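Since the cache now stores ids, tokenizer_st_partition looks the text up via id_to_token and splits each raw-text fragment around every occurrence. A simplified sketch of the splitting idea on a single string (the real function works on a forward_list of offset/length fragments and loops over every cached id, longest special tokens first; the segment struct and function name here are illustrative only):

```cpp
#include <string>
#include <vector>

struct segment {
    bool is_special;    // true if this segment is a special-token occurrence
    std::string text;
};

// Split 'raw_text' around every occurrence of 'special_token'.
std::vector<segment> split_on_special(const std::string & raw_text,
                                      const std::string & special_token) {
    std::vector<segment> out;
    size_t pos = 0;
    while (pos < raw_text.size()) {
        const size_t match = raw_text.find(special_token, pos);
        if (match == std::string::npos) {
            out.push_back({false, raw_text.substr(pos)});              // no more occurrences
            break;
        }
        if (match > pos) {
            out.push_back({false, raw_text.substr(pos, match - pos)}); // left remainder
        }
        out.push_back({true, special_token});
        pos = match + special_token.size();                            // continue on the right remainder
    }
    return out;
}
```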
tests/test-tokenizer-random.py: 20 changes (13 additions, 7 deletions)
@@ -167,8 +167,10 @@ def generator_random_special_tokens(tokenizer, iterations=100) -> Iterator[str]:
for m in range(iterations):
rand.seed(m)
words = rand.choices(special_tokens, k=500)
if tokenizer.add_bos_token: # skip spam warning of double BOS
while words and words[0] == tokenizer.bos_token:
if words[0] == tokenizer.bos_token: # skip spam warning of double BOS
while len(words) > 1 and words[1] == tokenizer.bos_token: # leave one starting BOS
words.pop(0)
if tokenizer.add_bos_token: # drop all starting BOS
words.pop(0)
yield "".join(words)

@@ -293,15 +295,17 @@ def main(argv: list[str] = None):
model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)

tokenizer.add_bos_token = getattr(tokenizer, "add_bos_token", True)
tokenizer.add_eos_token = getattr(tokenizer, "add_eos_token", False)

def func_tokenize1(text: str):
return model.tokenize(text, add_special=True, parse_special=True)

def func_tokenize2(text: str):
return tokenizer.encode(text, add_special_tokens=True)

ids = func_tokenize2("a")
assert 1 <= len(ids) <= 3
add_bos_token = len(ids) > 1 and tokenizer.bos_token_id == ids[0]
tokenizer.add_bos_token = getattr(tokenizer, "add_bos_token", add_bos_token)

vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text())
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
@@ -324,8 +328,10 @@ def func_tokenize2(text: str):
# import os
# tokenizers = os.listdir(path_tokenizers)
tokenizers = [
"llama-spm", # SPM
"phi-3", # SPM
# "llama-spm", # SPM
# "phi-3", # SPM
"jina-v2-en", # WPM
"bert-bge", # WPM
]

for tokenizer in tokenizers: