Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/repeat penalty #20

Merged
merged 6 commits into from
Mar 12, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,7 @@ int main(int argc, char ** argv) {
printf("%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
}
printf("\n");
printf("sampling parameters: temp = %f, top_k = %d, top_p = %f\n", params.temp, params.top_k, params.top_p);
printf("sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
printf("\n\n");

std::vector<gpt_vocab::id> embd;
Expand All @@ -801,6 +801,10 @@ int main(int argc, char ** argv) {
size_t mem_per_token = 0;
llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);

int last_n_size = params.repeat_last_n;
std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);

for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
// predict
if (embd.size() > 0) {
Expand All @@ -821,6 +825,7 @@ int main(int argc, char ** argv) {
// sample next token
const float top_p = params.top_p;
const float temp = params.temp;
const float repeat_penalty = params.repeat_penalty;

const int n_vocab = model.hparams.n_vocab;

Expand All @@ -829,7 +834,10 @@ int main(int argc, char ** argv) {
{
const int64_t t_start_sample_us = ggml_time_us();

id = llama_sample_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_p, temp, rng);
id = llama_sample_top_p(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_p, temp, rng);

last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);

t_sample_us += ggml_time_us() - t_start_sample_us;
}
Expand All @@ -840,6 +848,8 @@ int main(int argc, char ** argv) {
// if here, it means we are still processing the input prompt
for (int k = i; k < embd_inp.size(); k++) {
embd.push_back(embd_inp[k]);
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(embd_inp[k]);
if (embd.size() > params.n_batch) {
break;
}
Expand Down
22 changes: 21 additions & 1 deletion utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.top_p = std::stof(argv[++i]);
} else if (arg == "--temp") {
params.temp = std::stof(argv[++i]);
} else if (arg == "--repeat_last_n") {
params.repeat_last_n = std::stoi(argv[++i]);
} else if (arg == "--repeat_penalty") {
params.repeat_penalty = std::stof(argv[++i]);
} else if (arg == "-b" || arg == "--batch_size") {
params.n_batch = std::stoi(argv[++i]);
} else if (arg == "-m" || arg == "--model") {
Expand Down Expand Up @@ -52,6 +56,8 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict);
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", params.repeat_penalty);
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stderr, " -m FNAME, --model FNAME\n");
Expand Down Expand Up @@ -372,6 +378,8 @@ gpt_vocab::id gpt_sample_top_k_top_p(
gpt_vocab::id llama_sample_top_p(
const gpt_vocab & vocab,
const float * logits,
std::vector<gpt_vocab::id> & last_n_tokens,
double repeat_penalty,
double top_p,
double temp,
std::mt19937 & rng) {
Expand All @@ -383,7 +391,19 @@ gpt_vocab::id llama_sample_top_p(
{
const double scale = 1.0/temp;
for (int i = 0; i < n_logits; ++i) {
logits_id.push_back(std::make_pair(logits[i]*scale, i));
// repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
// credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
if ( std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end() ) {
// if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if(logits[i] < 0.0) {
logits_id.push_back(std::make_pair(logits[i]*scale*repeat_penalty, i));
} else {
logits_id.push_back(std::make_pair(logits[i]*scale/repeat_penalty, i));
}

} else {
logits_id.push_back(std::make_pair(logits[i]*scale, i));
}
ggerganov marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand Down
4 changes: 4 additions & 0 deletions utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ struct gpt_params {
int32_t seed = -1; // RNG seed
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_predict = 128; // new tokens to predict
int32_t repeat_last_n = 64; // last n tokens to penalize

// sampling parameters
int32_t top_k = 40; // unused
float top_p = 0.95f;
float temp = 0.80f;
float repeat_penalty = 1.30f;

int32_t n_batch = 8; // batch size for prompt processing

Expand Down Expand Up @@ -89,6 +91,8 @@ gpt_vocab::id gpt_sample_top_k_top_p(
gpt_vocab::id llama_sample_top_p(
const gpt_vocab & vocab,
const float * logits,
std::vector<gpt_vocab::id> & last_n_tokens,
double repeat_penalty,
double top_p,
double temp,
std::mt19937 & rng);
Expand Down