From f25ef3781c2dd83a1065a6c9bb7d54f6b88d696d Mon Sep 17 00:00:00 2001
From: metascroy <161522778+metascroy@users.noreply.github.com>
Date: Sat, 27 Apr 2024 11:19:53 -0700
Subject: [PATCH] Support llama3 in chat in run.cpp (#486)

* refactor chat runner in preparation for llama3

* add sketch for llama3 prompt template and move to returning tokens

* fix tiktoken

* fixes to chat

* add default llama_ver
---
 runner/run.cpp         | 278 +++++++++++++++++++++++++++++++----------
 tokenizer/tiktoken.cpp |   8 +-
 2 files changed, 219 insertions(+), 67 deletions(-)

diff --git a/runner/run.cpp b/runner/run.cpp
index 3c6b33e03b..07d67be08d 100644
--- a/runner/run.cpp
+++ b/runner/run.cpp
@@ -1,5 +1,8 @@
 /* Inference for Llama-2 Transformer model in pure C++ */
+#include
+#include
 #include
+#include
 #include
 #include
 #include
@@ -7,6 +10,8 @@
 #include
 #include
 #include
+#include
+
 #ifdef DEBUG
 #include
@@ -485,6 +490,175 @@ void read_stdin(const char* guide, char* buffer, size_t bufsize) {
 // python reference and that seemed ok, but this was not thoroughly tested and
 // is not safely implemented, it's more a proof of concept atm.
 
+enum class ModelType {
+  unknown,
+  llama2,
+  llama3,
+};
+
+ModelType get_model_type(Tokenizer* tokenizer) {
+  if (BPETokenizer* t = dynamic_cast<BPETokenizer*>(tokenizer)) {
+    return ModelType::llama2;
+  } else if (Tiktoken* t = dynamic_cast<Tiktoken*>(tokenizer)) {
+    return ModelType::llama3;
+  } else {
+    return ModelType::unknown;
+  }
+}
+
+uint64_t get_eot_token(Tokenizer* tokenizer) {
+  ModelType model_type = get_model_type(tokenizer);
+
+  if (model_type == ModelType::llama2) {
+    // llama2 uses EOS as EOT token
+    return tokenizer->eos_tok();
+  }
+
+  if (model_type == ModelType::llama3) {
+    auto tokens = tokenizer->encode("<|eot_id|>", 0, 0);
+    return tokens[0];
+  }
+
+  fprintf(stderr, "No chat template implementation for model type %d", model_type);
+  exit(EXIT_FAILURE);
+}
+
+std::vector<uint64_t> get_initial_prompt_tokens(const char* cli_system_prompt, const char* cli_user_prompt, Tokenizer* tokenizer) {
+  char system_prompt[512];
+  char user_prompt[512];
+  char rendered_prompt[512*2 + 200]; // the prompt template is ~170 characters. We use 200 to be safe.
+
+  if (cli_system_prompt != NULL) {
+    strcpy(system_prompt, cli_system_prompt);
+  } else {
+    read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
+  }
+
+  if (cli_user_prompt != NULL) {
+    strcpy(user_prompt, cli_user_prompt);
+  } else {
+    read_stdin("User: ", user_prompt, sizeof(user_prompt));
+  }
+
+  ModelType model_type = get_model_type(tokenizer);
+  std::vector<uint64_t> tokens;
+
+  switch (model_type) {
+
+    case ModelType::llama2:
+      if (system_prompt[0] != '\0') {
+        snprintf(
+          rendered_prompt,
+          sizeof(rendered_prompt)-1,
+          "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]",
+          system_prompt,
+          user_prompt
+        );
+      } else {
+        // const char prompt_template[] = ;
+        snprintf(
+          rendered_prompt,
+          sizeof(rendered_prompt)-1,
+          "[INST] %s [/INST]",
+          user_prompt
+        );
+      }
+
+      // We need to add BOS token here and not in template because llama2 tokenizer
+      // does not pattern match special tokens
+      tokens = tokenizer->encode(rendered_prompt, 1, 0);
+      break;
+
+    case ModelType::llama3:
+      if (system_prompt[0] != '\0') {
+        snprintf(
+          rendered_prompt,
+          sizeof(rendered_prompt)-1,
+          "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+          system_prompt,
+          user_prompt
+        );
+      } else {
+        snprintf(
+          rendered_prompt,
+          sizeof(rendered_prompt)-1,
+          "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+          user_prompt
+        );
+      }
+      tokens = tokenizer->encode(rendered_prompt, 0, 0);
+      break;
+
+    default:
+      fprintf(stderr, "No chat template implementation for model type %d", model_type);
+      exit(EXIT_FAILURE);
+  }
+
+  #ifdef DEBUG
+  std::cerr << "Start of rendered prompt:" << std::endl;
+  std::cerr << rendered_prompt;
+  std::cerr << "End of rendered prompt:" << std::endl;
+  std::cerr << "Encoded prompt: ";
+  for (int i = 0; i < tokens.size(); i++) {
+    std::cerr << tokens[i] << ", ";
+  }
+  std::cerr << std::endl << std::flush;
+  #endif
+
+  return tokens;
+}
+
+std::vector<uint64_t> get_next_user_prompt_tokens(Tokenizer* tokenizer) {
+  char user_prompt[512];
+  char rendered_prompt[512 + 150]; // the prompt template is ~100 characters. We use 150 to be safe.
+
+  read_stdin("User: ", user_prompt, sizeof(user_prompt));
+
+  ModelType model_type = get_model_type(tokenizer);
+  std::vector<uint64_t> tokens;
+
+  switch (model_type) {
+
+    case ModelType::llama2:
+      // const char prompt_template[] = ;
+      snprintf(rendered_prompt, sizeof(rendered_prompt)-1, "[INST] %s [/INST]", user_prompt);
+
+      // We need to add BOS token here and not in template because llama2 tokenizer
+      // does not pattern match special tokens
+      tokens = tokenizer->encode(rendered_prompt, /*bos*/1, /*eos*/0);
+      break;
+
+    case ModelType::llama3:
+      snprintf(
+        rendered_prompt,
+        sizeof(rendered_prompt)-1,
+        "<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+        user_prompt
+      );
+      tokens = tokenizer->encode(rendered_prompt, 0, 0);
+      break;
+
+    default:
+      fprintf(stderr, "No chat template implementation for model type %d", model_type);
+      exit(EXIT_FAILURE);
+  }
+
+
+  #ifdef DEBUG
+  std::cerr << "Start of rendered prompt:" << std::endl;
+  std::cerr << rendered_prompt;
+  std::cerr << "End of rendered prompt:" << std::endl;
+  std::cerr << "Encoded prompt: ";
+  for (int i = 0; i < tokens.size(); i++) {
+    std::cerr << tokens[i] << ", ";
+  }
+  std::cerr << std::endl << std::flush;
+  #endif
+
+  return tokens;
+}
+
+
 void chat(
     Transformer* transformer,
     Tokenizer* tokenizer,
@@ -492,20 +666,8 @@ void chat(
     const char* cli_user_prompt,
     const char* cli_system_prompt,
     int steps) {
-  // special tokens
-  const int SOS_TOKEN = tokenizer->bos_tok(); // <s> token starts the assistant turn
-  const int EOS_TOKEN = tokenizer->eos_tok(); // </s> token ends the assistant turn
-  const int SYSTEM_PROMPT_SIZE = 512;
-  const int USER_PROMPT_SIZE = 512;
-  const int RENDERED_PROMPT_SIZE = SYSTEM_PROMPT_SIZE + USER_PROMPT_SIZE + 128; // This is big enough to hold the expanded template
-
-
-  // buffers for reading the system prompt and user prompt from stdin
-  // you'll notice they are somewhat haphazardly and unsafely set atm
-  char system_prompt[SYSTEM_PROMPT_SIZE];
-  char user_prompt[USER_PROMPT_SIZE];
-  char rendered_prompt[RENDERED_PROMPT_SIZE];
+  const uint64_t EOT_TOKEN = get_eot_token(tokenizer);
   int num_prompt_tokens = 0;
   std::vector<uint64_t> prompt_tokens;
   int user_idx;
@@ -522,41 +684,10 @@ void chat(
     if (user_turn) {
       // get the (optional) system prompt at position 0
       if (pos == 0) {
-        // at position 0, the user can also contribute a system prompt
-        if (cli_system_prompt == NULL) {
-          // system prompt was not passed in, attempt to get it from stdin
-          read_stdin(
-              "Enter system prompt (optional): ",
-              system_prompt,
-              sizeof(system_prompt));
-        } else {
-          // system prompt was passed in, use it
-          strcpy(system_prompt, cli_system_prompt);
-        }
-      }
-      // get the user prompt
-      if (pos == 0 && cli_user_prompt != NULL) {
-        // user prompt for position 0 was passed in, use it
-        strcpy(user_prompt, cli_user_prompt);
+        prompt_tokens = get_initial_prompt_tokens(cli_system_prompt, cli_user_prompt, tokenizer);
       } else {
-        // otherwise get user prompt from stdin
-        read_stdin("User: ", user_prompt, sizeof(user_prompt));
+        prompt_tokens = get_next_user_prompt_tokens(tokenizer);
       }
-      // render user/system prompts into the Llama 2 Chat schema
-      if (pos == 0 && system_prompt[0] != '\0') {
-        // We do not add <s> because that is added by tokenizer->encode(x, 1, 0)
-        const char system_template[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
-        snprintf(
-            rendered_prompt, RENDERED_PROMPT_SIZE-1, system_template, system_prompt, user_prompt);
-      } else {
-        // Assistant should produce </s>, so we do not include it in template
-        // We do not add <s> because that is added by tokenizer->encode(x, 1, 0)
-        const char user_template[] = "[INST] %s [/INST]";
-        snprintf(rendered_prompt, RENDERED_PROMPT_SIZE-1, user_template, user_prompt);
-      }
-
-      // encode the rendered prompt into tokens
-      prompt_tokens = tokenizer->encode(rendered_prompt, 1, 0);
       num_prompt_tokens = prompt_tokens.size();
 
       user_idx = 0; // reset the user index
@@ -578,19 +709,21 @@ void chat(
     float* logits = forward(transformer, token, pos);
     next = sample(sampler, logits);
+    // std::cout << "TOKEN: " << token << " NEXT: " << next << std::endl;
 
-    if (token == EOS_TOKEN) {
+
+    if ((user_idx >= num_prompt_tokens) && (token == EOT_TOKEN)) {
       user_turn = 1;
     }
 
-    if (user_idx >= num_prompt_tokens && token != EOS_TOKEN && next != EOS_TOKEN) {
+    if (user_idx >= num_prompt_tokens && token != EOT_TOKEN && next != EOT_TOKEN) {
       std::string piece = tokenizer->decode(token, next);
       safe_printf(piece.c_str()); // same as printf("%s", piece), but skips
                                   // "unsafe" bytes
       fflush(stdout);
     }
 
-    if (next == EOS_TOKEN) {
+    if (next == EOT_TOKEN) {
       printf("\n");
     }
     pos++;
@@ -619,6 +752,7 @@ void error_usage() {
   fprintf(stderr, "  -z optional path to custom tokenizer\n");
   fprintf(stderr, "  -m mode: generate|chat, default: generate\n");
   fprintf(stderr, "  -y (optional) system prompt in chat mode\n");
+  fprintf(stderr, "  -l (optional) llama version (2 or 3). Defaults to 2.\n");
   exit(EXIT_FAILURE);
 }
 
@@ -630,7 +764,7 @@ int main(int argc, char* argv[]) {
       1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
   float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
-  int vocab_size = 32000;
+
   int steps = 256; // number of steps to run for
   const char* prompt = NULL; // prompt string
   unsigned long long rng_seed = 0; // seed rng with time by default
@@ -638,6 +772,9 @@ int main(int argc, char* argv[]) {
   char* system_prompt =
       NULL; // the (optional) system prompt to use in chat mode
 
+  int vocab_size = -1;
+  int llama_ver = 2;
+
 #if defined(ET_USE_ADPATIVE_THREADS)
   uint32_t num_performant_cores =
       torch::executorch::cpuinfo::get_num_performant_cores();
   if (num_performant_cores > 0) {
@@ -682,7 +819,10 @@ int main(int argc, char* argv[]) {
       mode = argv[i + 1];
     } else if (argv[i][1] == 'y') {
       system_prompt = argv[i + 1];
-    } else {
+    } else if (argv[i][1] == 'l') {
+      llama_ver = atoi(argv[i+1]);
+    }
+    else {
       error_usage();
     }
   }
@@ -697,6 +837,15 @@ int main(int argc, char* argv[]) {
   if (steps < 0)
     steps = 0;
 
+
+  if (vocab_size == -1) {
+    if (llama_ver == 2) {
+      vocab_size = 32000;
+    } else {
+      vocab_size = 128256;
+    }
+  }
+
   // build the Transformer via the model .bin file
   Transformer transformer;
   build_transformer(&transformer, checkpoint_path, vocab_size, steps);
@@ -704,20 +853,19 @@ int main(int argc, char* argv[]) {
   // build the Tokenizer via the tokenizer .bin file
   Tokenizer* tokenizer = nullptr;
 
-  // Try to load using Tiktoken, if exception then switch to another tokenizer
-  try {
-    tokenizer =
-        new Tiktoken(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
-    tokenizer->load(tokenizer_path);
-  } catch (const std::invalid_argument&) {
-    fprintf(
-        stderr,
-        "Failed to load %s into a Tiktoken tokenizer. Trying sentencepiece tokenizer..\n",
-        tokenizer_path);
-    delete tokenizer;
-    tokenizer =
-        new BPETokenizer(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
-    tokenizer->load(tokenizer_path);
+  switch (llama_ver) {
+    case 2:
+      tokenizer = new BPETokenizer(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
+      tokenizer->load(tokenizer_path);
+      break;
+    case 3:
+      tokenizer = new Tiktoken(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
+      tokenizer->load(tokenizer_path);
+      break;
+
+    default:
+      fprintf(stderr, "Cannot load tokenizer for unrecognized llama version %d", llama_ver);
+      exit(EXIT_FAILURE);
   }
 
   // build the Sampler
diff --git a/tokenizer/tiktoken.cpp b/tokenizer/tiktoken.cpp
index f30488b8b0..11cec0c58a 100644
--- a/tokenizer/tiktoken.cpp
+++ b/tokenizer/tiktoken.cpp
@@ -331,9 +331,9 @@ std::pair<std::vector<uint64_t>, uint64_t> Tiktoken::_encode_with_special_token(
 
 Tiktoken::Tiktoken(int32_t vocab_size, uint64_t bos_tok, uint64_t eos_tok)
     : Tokenizer(vocab_size, bos_tok, eos_tok) {
-  _regex = _create_regex(_pattern);
-  _special_token_regex = _build_special_token_regex(_special_token_encoder);
+  // _regex = _create_regex(_pattern);
+  // _special_token_regex = _build_special_token_regex(_special_token_encoder);
 }
 
 void Tiktoken::load(const std::string& path) {
@@ -343,6 +343,10 @@ void Tiktoken::load(const std::string& path) {
 
   _decoder = _build_decoder(_encoder);
   _special_token_decoder = _build_decoder(_special_token_encoder);
+
+  _regex = _create_regex(_pattern);
+  _special_token_regex = _build_special_token_regex(_special_token_encoder);
+
   initialized_ = true;
 }
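
For reference, here is a minimal standalone sketch (not part of the patch) that prints the llama3 chat layout the new helpers render. The two format strings are copied verbatim from the diff above; the file name and the example system/user prompts are illustrative assumptions.

// chat_template_demo.cpp -- illustrative sketch only, not part of the patch.
// Prints the llama3 chat layouts rendered by get_initial_prompt_tokens (first
// turn, with a system prompt) and get_next_user_prompt_tokens (later turns).
#include <cstdio>

int main() {
  char rendered_prompt[512 * 2 + 200];

  // First turn: <|begin_of_text|>, system message, user message, then an open
  // assistant header so generation continues as the assistant.
  snprintf(
      rendered_prompt,
      sizeof(rendered_prompt) - 1,
      "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|>"
      "<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|>"
      "<|start_header_id|>assistant<|end_header_id|>\n\n",
      "You are a helpful assistant.",    // example system prompt (assumption)
      "What is the capital of France?"); // example user prompt (assumption)
  printf("%s", rendered_prompt);

  // Later turns: no <|begin_of_text|>; just the next user message and another
  // open assistant header. The chat loop above stops emitting text at the
  // <|eot_id|> token returned by get_eot_token.
  snprintf(
      rendered_prompt,
      sizeof(rendered_prompt) - 1,
      "<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|>"
      "<|start_header_id|>assistant<|end_header_id|>\n\n",
      "And of Germany?"); // example follow-up prompt (assumption)
  printf("%s", rendered_prompt);
  return 0;
}

Built on its own (for example with g++), this simply prints the two rendered strings, which makes it easy to compare the llama3 layout against the llama2 "[INST] ... [/INST]" template handled in the same switch statements.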