Support llama3 in chat in run.cpp (pytorch#486)
* refactor chat runner in preparation for llama3

* add sketch for llama3 prompt template and move to returning tokens

* fix tiktoken

* fixes to chat

* add default llama_ver
metascroy authored and malfet committed Jul 17, 2024
1 parent a14f9f8 commit f25ef37
Showing 2 changed files with 219 additions and 67 deletions.
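
For reference, here is a minimal standalone sketch of the two chat templates this commit introduces. The format strings are copied from get_initial_prompt_tokens in the diff that follows; the system/user strings are illustrative placeholders, not part of the commit.

#include <cstdio>

int main() {
  const char* system_prompt = "You are a helpful assistant.";
  const char* user_prompt = "Hello!";
  char rendered[2048];

  // Llama 2: [INST]/<<SYS>> schema. BOS is not part of the template; it is
  // added later by tokenizer->encode(rendered, /*bos*/1, /*eos*/0) because the
  // sentencepiece tokenizer does not pattern-match special tokens.
  snprintf(rendered, sizeof(rendered),
           "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]",
           system_prompt, user_prompt);
  printf("%s\n\n", rendered);

  // Llama 3: header-id schema. Special tokens such as <|begin_of_text|> live
  // in the template itself and are pattern-matched by the Tiktoken tokenizer,
  // so encode() is called with bos=0.
  snprintf(rendered, sizeof(rendered),
           "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s"
           "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n%s"
           "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
           system_prompt, user_prompt);
  printf("%s\n", rendered);
  return 0;
}
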
278 changes: 213 additions & 65 deletions runner/run.cpp
@@ -1,12 +1,17 @@
/* Inference for Llama-2 Transformer model in pure C++ */
#include <cstdint>
#include <cstdlib>
#include <ctype.h>
#include <iterator>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <tokenizer.h>
#include <string>


#ifdef DEBUG
#include <cassert>
@@ -485,27 +490,184 @@ void read_stdin(const char* guide, char* buffer, size_t bufsize) {
// python reference and that seemed ok, but this was not thoroughly tested and
// is not safely implemented, it's more a proof of concept atm.

enum class ModelType {
unknown,
llama2,
llama3,
};

ModelType get_model_type(Tokenizer* tokenizer) {
if (BPETokenizer* t = dynamic_cast<BPETokenizer*>(tokenizer)) {
return ModelType::llama2;
} else if (Tiktoken* t = dynamic_cast<Tiktoken*>(tokenizer)) {
return ModelType::llama3;
} else {
return ModelType::unknown;
}
}

uint64_t get_eot_token(Tokenizer* tokenizer) {
ModelType model_type = get_model_type(tokenizer);

if (model_type == ModelType::llama2) {
// llama2 uses EOS as EOT token
return tokenizer->eos_tok();
}

if (model_type == ModelType::llama3) {
auto tokens = tokenizer->encode("<|eot_id|>", 0, 0);
return tokens[0];
}

fprintf(stderr, "No chat template implemnation for model type %d", model_type);
exit(EXIT_FAILURE);
}

std::vector<uint64_t> get_initial_prompt_tokens(const char* cli_system_prompt, const char* cli_user_prompt, Tokenizer* tokenizer) {
char system_prompt[512];
char user_prompt[512];
char rendered_prompt[512*2 + 200]; // the prompt template is ~170 characters. We use 200 to be safe.

if (cli_system_prompt != NULL) {
strcpy(system_prompt, cli_system_prompt);
} else {
read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
}

if (cli_user_prompt != NULL) {
strcpy(user_prompt, cli_user_prompt);
} else {
read_stdin("User: ", user_prompt, sizeof(user_prompt));
}

ModelType model_type = get_model_type(tokenizer);
std::vector<uint64_t> tokens;

switch (model_type) {

case ModelType::llama2:
if (system_prompt[0] != '\0') {
snprintf(
rendered_prompt,
sizeof(rendered_prompt)-1,
"[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]",
system_prompt,
user_prompt
);
} else {
// const char prompt_template[] = ;
snprintf(
rendered_prompt,
sizeof(rendered_prompt)-1,
"[INST] %s [/INST]",
user_prompt
);
}

// We need to add BOS token here and not in template because llama2 tokenizer
// does not pattern match special tokens
tokens = tokenizer->encode(rendered_prompt, 1, 0);
break;

case ModelType::llama3:
if (system_prompt[0] != '\0') {
snprintf(
rendered_prompt,
sizeof(rendered_prompt)-1,
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
system_prompt,
user_prompt
);
} else {
snprintf(
rendered_prompt,
sizeof(rendered_prompt)-1,
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
user_prompt
);
}
tokens = tokenizer->encode(rendered_prompt, 0, 0);
break;

default:
fprintf(stderr, "No chat template implemnation for model type %d", model_type);
exit(EXIT_FAILURE);
}

#ifdef DEBUG
std::cerr << "Start of rendered prompt:" << std::endl;
std::cerr << rendered_prompt;
std::cerr << "End of rendered prompt:" << std::endl;
std::cerr << "Encoded prompt: ";
for (int i = 0; i < tokens.size(); i++) {
std::cerr << tokens[i] << ", ";
}
std::cerr << std::endl << std::flush;
#endif

return tokens;
}

std::vector<uint64_t> get_next_user_prompt_tokens(Tokenizer* tokenizer) {
char user_prompt[512];
char rendered_prompt[512 + 150]; // the prompt template is ~100 characters. We use 150 to be safe.

read_stdin("User: ", user_prompt, sizeof(user_prompt));

ModelType model_type = get_model_type(tokenizer);
std::vector<uint64_t> tokens;

switch (model_type) {

case ModelType::llama2:
// const char prompt_template[] = ;
snprintf(rendered_prompt, sizeof(rendered_prompt)-1, "[INST] %s [/INST]", user_prompt);

// We need to add BOS token here and not in template because llama2 tokenizer
// does not pattern match special tokens
tokens = tokenizer->encode(rendered_prompt, /*bos*/1, /*eos*/0);
break;

case ModelType::llama3:
snprintf(
rendered_prompt,
sizeof(rendered_prompt)-1,
"<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
user_prompt
);
tokens = tokenizer->encode(rendered_prompt, 0, 0);
break;

default:
fprintf(stderr, "No chat template implemnation for model type %d", model_type);
exit(EXIT_FAILURE);
}


#ifdef DEBUG
std::cerr << "Start of rendered prompt:" << std::endl;
std::cerr << rendered_prompt;
std::cerr << "End of rendered prompt:" << std::endl;
std::cerr << "Encoded prompt: ";
for (int i = 0; i < tokens.size(); i++) {
std::cerr << tokens[i] << ", ";
}
std::cerr << std::endl << std::flush;
#endif

return tokens;
}
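
As an aside (derived from the two template functions above): only the initial prompt carries the conversation-level start token; follow-up turns add just the user/assistant headers. A sketch of a two-turn Llama 3 exchange, with placeholder text:

// Turn 1: get_initial_prompt_tokens renders and encodes (bos=0, eos=0):
//   <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system}
//   <|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{user}
//   <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n
// The model then generates the assistant reply, terminated by <|eot_id|>,
// the token returned by get_eot_token above.
//
// Turn 2 and later: get_next_user_prompt_tokens renders and encodes only:
//   <|start_header_id|>user<|end_header_id|>\n\n{user}
//   <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n
// so <|begin_of_text|> appears exactly once per conversation. For Llama 2,
// encode() instead adds BOS to every [INST] block.
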


void chat(
Transformer* transformer,
Tokenizer* tokenizer,
Sampler* sampler,
const char* cli_user_prompt,
const char* cli_system_prompt,
int steps) {
// special tokens
const int SOS_TOKEN = tokenizer->bos_tok(); // token starts the assistant turn
const int EOS_TOKEN = tokenizer->eos_tok(); // token ends the assistant turn
const int SYSTEM_PROMPT_SIZE = 512;
const int USER_PROMPT_SIZE = 512;
const int RENDERED_PROMPT_SIZE = SYSTEM_PROMPT_SIZE + USER_PROMPT_SIZE + 128; // This is big enough to hold the expanded template



// buffers for reading the system prompt and user prompt from stdin
// you'll notice they are somewhat haphazardly and unsafely set atm
char system_prompt[SYSTEM_PROMPT_SIZE];
char user_prompt[USER_PROMPT_SIZE];
char rendered_prompt[RENDERED_PROMPT_SIZE];
const uint64_t EOT_TOKEN = get_eot_token(tokenizer);
int num_prompt_tokens = 0;
std::vector<uint64_t> prompt_tokens;
int user_idx;
@@ -522,41 +684,10 @@ void chat(
if (user_turn) {
// get the (optional) system prompt at position 0
if (pos == 0) {
// at position 0, the user can also contribute a system prompt
if (cli_system_prompt == NULL) {
// system prompt was not passed in, attempt to get it from stdin
read_stdin(
"Enter system prompt (optional): ",
system_prompt,
sizeof(system_prompt));
} else {
// system prompt was passed in, use it
strcpy(system_prompt, cli_system_prompt);
}
}
// get the user prompt
if (pos == 0 && cli_user_prompt != NULL) {
// user prompt for position 0 was passed in, use it
strcpy(user_prompt, cli_user_prompt);
prompt_tokens = get_initial_prompt_tokens(cli_system_prompt, cli_user_prompt, tokenizer);
} else {
// otherwise get user prompt from stdin
read_stdin("User: ", user_prompt, sizeof(user_prompt));
prompt_tokens = get_next_user_prompt_tokens(tokenizer);
}
// render user/system prompts into the Llama 2 Chat schema
if (pos == 0 && system_prompt[0] != '\0') {
// We do not add <s> because that is added by tokenizer->encode(x, 1, 0)
const char system_template[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
snprintf(
rendered_prompt, RENDERED_PROMPT_SIZE-1, system_template, system_prompt, user_prompt);
} else {
// Assistant should produce </s>, so we do not include it in template
// We do not add <s> because that is added by tokenizer->encode(x, 1, 0)
const char user_template[] = "[INST] %s [/INST]";
snprintf(rendered_prompt, RENDERED_PROMPT_SIZE-1, user_template, user_prompt);
}

// encode the rendered prompt into tokens
prompt_tokens = tokenizer->encode(rendered_prompt, 1, 0);
num_prompt_tokens = prompt_tokens.size();

user_idx = 0; // reset the user index
@@ -578,19 +709,21 @@ void chat(
float* logits = forward(transformer, token, pos);
next = sample(sampler, logits);

// std::cout << "TOKEN: " << token << " NEXT: " << next << std::endl;

if (token == EOS_TOKEN) {

if ((user_idx >= num_prompt_tokens) && (token == EOT_TOKEN)) {
user_turn = 1;
}

if (user_idx >= num_prompt_tokens && token != EOS_TOKEN && next != EOS_TOKEN) {
if (user_idx >= num_prompt_tokens && token != EOT_TOKEN && next != EOT_TOKEN) {
std::string piece = tokenizer->decode(token, next);
safe_printf(piece.c_str()); // same as printf("%s", piece), but skips
// "unsafe" bytes
fflush(stdout);
}

if (next == EOS_TOKEN) {
if (next == EOT_TOKEN) {
printf("\n");
}
pos++;
@@ -619,6 +752,7 @@ void error_usage() {
fprintf(stderr, " -z <string> optional path to custom tokenizer\n");
fprintf(stderr, " -m <string> mode: generate|chat, default: generate\n");
fprintf(stderr, " -y <string> (optional) system prompt in chat mode\n");
fprintf(stderr, " -l <int> (optional) llama version (2 or 3). Defaults to 2.\n");
exit(EXIT_FAILURE);
}

@@ -630,14 +764,17 @@ int main(int argc, char* argv[]) {
1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
float topp =
0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
int vocab_size = 32000;

int steps = 256; // number of steps to run for
const char* prompt = NULL; // prompt string
unsigned long long rng_seed = 0; // seed rng with time by default
const char* mode = "generate"; // generate|chat
char* system_prompt =
NULL; // the (optional) system prompt to use in chat mode

int vocab_size = -1;
int llama_ver = 2;

#if defined(ET_USE_ADPATIVE_THREADS)
uint32_t num_performant_cores = torch::executorch::cpuinfo::get_num_performant_cores();
if (num_performant_cores > 0) {
@@ -682,7 +819,10 @@ int main(int argc, char* argv[]) {
mode = argv[i + 1];
} else if (argv[i][1] == 'y') {
system_prompt = argv[i + 1];
} else {
} else if (argv[i][1] == 'l') {
llama_ver = atoi(argv[i+1]);
}
else {
error_usage();
}
}
@@ -697,27 +837,35 @@ int main(int argc, char* argv[]) {
if (steps < 0)
steps = 0;


if (vocab_size == -1) {
if (llama_ver == 2) {
vocab_size = 32000;
} else {
vocab_size = 128256;
}
}

// build the Transformer via the model .bin file
Transformer transformer;
build_transformer(&transformer, checkpoint_path, vocab_size, steps);

// build the Tokenizer via the tokenizer .bin file
Tokenizer* tokenizer = nullptr;

// Try to load using Tiktoken, if exception then switch to another tokenizer
try {
tokenizer =
new Tiktoken(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
tokenizer->load(tokenizer_path);
} catch (const std::invalid_argument&) {
fprintf(
stderr,
"Failed to load %s into a Tiktoken tokenizer. Trying sentencepiece tokenizer..\n",
tokenizer_path);
delete tokenizer;
tokenizer =
new BPETokenizer(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
tokenizer->load(tokenizer_path);
switch (llama_ver) {
case 2:
tokenizer = new BPETokenizer(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
tokenizer->load(tokenizer_path);
break;
case 3:
tokenizer = new Tiktoken(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
tokenizer->load(tokenizer_path);
break;

default:
fprintf(stderr, "Cannot load tokenizer for unrecognized llama version %d", llama_ver);
exit(EXIT_FAILURE);
}

// build the Sampler
8 changes: 6 additions & 2 deletions tokenizer/tiktoken.cpp
@@ -331,9 +331,9 @@ std::pair<std::vector<uint64_t>, uint64_t> Tiktoken::_encode_with_special_token(

Tiktoken::Tiktoken(int32_t vocab_size, uint64_t bos_tok, uint64_t eos_tok)
: Tokenizer(vocab_size, bos_tok, eos_tok) {
_regex = _create_regex(_pattern);

_special_token_regex = _build_special_token_regex(_special_token_encoder);
// _regex = _create_regex(_pattern);
// _special_token_regex = _build_special_token_regex(_special_token_encoder);
}

void Tiktoken::load(const std::string& path) {
@@ -343,6 +343,10 @@ void Tiktoken::load(const std::string& path) {
_decoder = _build_decoder(_encoder);
_special_token_decoder = _build_decoder(_special_token_encoder);


_regex = _create_regex(_pattern);
_special_token_regex = _build_special_token_regex(_special_token_encoder);

initialized_ = true;
}
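
The tiktoken change above moves regex construction out of the constructor and into load(), presumably because the special-token table consumed by _build_special_token_regex is only populated while the vocabulary is being loaded. A minimal sketch of that ordering, with hypothetical names and illustrative token ids (not the real Tiktoken members):

#include <cstdint>
#include <map>
#include <string>

struct TiktokenSketch {
  std::map<std::string, uint64_t> special_token_encoder;
  std::string special_token_pattern;  // stands in for _special_token_regex

  TiktokenSketch() {
    // Building the pattern here would see an empty special_token_encoder --
    // the situation the commit message refers to as "fix tiktoken".
  }

  void load() {
    // Populate the table first (stubbed here; the real code parses a vocab
    // file)...
    special_token_encoder = {{"<|begin_of_text|>", 128000},
                             {"<|eot_id|>", 128009}};
    // ...then build anything derived from it, e.g. an alternation over all
    // special tokens (the real code compiles this into a regex).
    for (const auto& kv : special_token_encoder) {
      if (!special_token_pattern.empty()) special_token_pattern += "|";
      special_token_pattern += kv.first;
    }
  }
};

int main() {
  TiktokenSketch t;
  t.load();  // pattern construction happens here, after the table exists
  return 0;
}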
