From 149603ea0c6b435342506ec648088f572357dba6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 6 May 2024 21:53:28 +0200 Subject: [PATCH] tokenization: no double BOS tokens --- llama.cpp | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/llama.cpp b/llama.cpp index aeb5c08df64a5c..b45446e90e370b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -12674,11 +12674,6 @@ static std::vector llama_tokenize_internal(const llama_vocab & // tokenizer.encode('', add_special_tokens=True) returns [1] // tokenizer.encode('', add_special_tokens=False) returns [] - if (add_special && vocab.special_add_bos != 0) { - GGML_ASSERT(vocab.special_bos_id != -1); - output.push_back(vocab.special_bos_id); - } - for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { // without adding this leading whitespace, we do not get the same results as the original tokenizer @@ -12705,6 +12700,11 @@ static std::vector llama_tokenize_internal(const llama_vocab & } } + if (add_special && vocab.special_add_bos != 0 && output[0] != vocab.special_bos_id) { + GGML_ASSERT(vocab.special_bos_id != -1); + output.insert(output.begin(), vocab.special_bos_id); + } + if (add_special && vocab.special_add_eos == 1) { GGML_ASSERT(vocab.special_eos_id != -1); output.push_back(vocab.special_eos_id); @@ -12712,11 +12712,6 @@ static std::vector llama_tokenize_internal(const llama_vocab & } break; case LLAMA_VOCAB_TYPE_BPE: { - if (add_special && vocab.special_add_bos != 0) { - GGML_ASSERT(vocab.special_bos_id != -1); - output.push_back(vocab.special_bos_id); - } - for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); @@ -12731,6 +12726,11 @@ static std::vector llama_tokenize_internal(const llama_vocab & } } + if (add_special && vocab.special_add_bos != 0 && output[0] != vocab.special_bos_id) { + GGML_ASSERT(vocab.special_bos_id != -1); + output.insert(output.begin(), vocab.special_bos_id); + } + GGML_ASSERT(vocab.special_add_eos != 1); } break; case LLAMA_VOCAB_TYPE_WPM: