Add back top_k (#56)
* Add back top_k

* Update utils.cpp

* Update utils.h

---------

Co-authored-by: Bill Hamilton <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
3 people authored Mar 12, 2023
1 parent eb062bb commit 02f0c6f
Showing 3 changed files with 12 additions and 89 deletions.
3 changes: 2 additions & 1 deletion main.cpp
@@ -825,6 +825,7 @@ int main(int argc, char ** argv) {
 
         if (i >= embd_inp.size()) {
             // sample next token
+            const float top_k = params.top_k;
             const float top_p = params.top_p;
             const float temp = params.temp;
             const float repeat_penalty = params.repeat_penalty;
@@ -836,7 +837,7 @@ int main(int argc, char ** argv) {
             {
                 const int64_t t_start_sample_us = ggml_time_us();
 
-                id = llama_sample_top_p(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_p, temp, rng);
+                id = llama_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng);
 
                 last_n_tokens.erase(last_n_tokens.begin());
                 last_n_tokens.push_back(id);
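For context on the change above: params.top_k is now threaded through to the sampler next to top-p, temperature, and the repeat penalty, while the last_n_tokens window that feeds the repeat penalty keeps a fixed size by dropping the oldest id each time a new one is sampled. A minimal standalone sketch of that window behavior, with illustrative names and a window size of 4 rather than the commit's default of 64:

#include <cstdio>
#include <vector>

int main() {
    // hypothetical stand-in for the repeat_last_n window maintained in main.cpp
    const int repeat_last_n = 4;
    std::vector<int> last_n_tokens(repeat_last_n, 0);

    const int sampled[] = {11, 22, 33, 44, 55};     // pretend sampler output
    for (int id : sampled) {
        last_n_tokens.erase(last_n_tokens.begin()); // drop the oldest token id
        last_n_tokens.push_back(id);                // append the newest
    }

    for (int id : last_n_tokens) printf("%d ", id); // prints: 22 33 44 55
    printf("\n");
    return 0;
}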
79 changes: 4 additions & 75 deletions utils.cpp
@@ -301,25 +301,8 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
     return true;
 }
 
-gpt_vocab::id gpt_sample_top_k_top_p(
-        const gpt_vocab & vocab,
-        const float * logits,
-        int top_k,
-        double top_p,
-        double temp,
-        std::mt19937 & rng) {
-    int n_logits = vocab.id_to_token.size();
-
-    std::vector<std::pair<double, gpt_vocab::id>> logits_id;
-    logits_id.reserve(n_logits);
-
-    {
-        const double scale = 1.0/temp;
-        for (int i = 0; i < n_logits; ++i) {
-            logits_id.push_back(std::make_pair(logits[i]*scale, i));
-        }
-    }
-
+void sample_top_k(std::vector<std::pair<double, gpt_vocab::id>> & logits_id, int top_k) {
     // find the top K tokens
     std::partial_sort(
             logits_id.begin(),
@@ -329,63 +312,14 @@ gpt_vocab::id gpt_sample_top_k_top_p(
     });
 
     logits_id.resize(top_k);
-
-    double maxl = -INFINITY;
-    for (const auto & kv : logits_id) {
-        maxl = std::max(maxl, kv.first);
-    }
-
-    // compute probs for the top K tokens
-    std::vector<double> probs;
-    probs.reserve(logits_id.size());
-
-    double sum = 0.0;
-    for (const auto & kv : logits_id) {
-        double p = exp(kv.first - maxl);
-        probs.push_back(p);
-        sum += p;
-    }
-
-    // normalize the probs
-    for (auto & p : probs) {
-        p /= sum;
-    }
-
-    if (top_p < 1.0f) {
-        double cumsum = 0.0f;
-        for (int i = 0; i < top_k; i++) {
-            cumsum += probs[i];
-            if (cumsum >= top_p) {
-                top_k = i + 1;
-                probs.resize(top_k);
-                logits_id.resize(top_k);
-                break;
-            }
-        }
-
-        cumsum = 1.0/cumsum;
-        for (int i = 0; i < (int) probs.size(); i++) {
-            probs[i] *= cumsum;
-        }
-    }
-
-    //printf("\n");
-    //for (int i = 0; i < (int) probs.size(); i++) {
-    //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
-    //}
-    //exit(0);
-
-    std::discrete_distribution<> dist(probs.begin(), probs.end());
-    int idx = dist(rng);
-
-    return logits_id[idx].second;
 }
 
-gpt_vocab::id llama_sample_top_p(
+gpt_vocab::id llama_sample_top_p_top_k(
         const gpt_vocab & vocab,
         const float * logits,
         std::vector<gpt_vocab::id> & last_n_tokens,
         double repeat_penalty,
+        int top_k,
         double top_p,
         double temp,
         std::mt19937 & rng) {
@@ -412,12 +346,7 @@ gpt_vocab::id llama_sample_top_p(
         }
     }
 
-    std::sort(
-            logits_id.begin(),
-            logits_id.end(),
-            [](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
-        return a.first > b.first;
-    });
+    sample_top_k(logits_id, top_k);
 
     double maxl = -INFINITY;
     for (const auto & kv : logits_id) {
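Taken together, the utils.cpp changes leave llama_sample_top_p_top_k with this flow: scale logits by 1/temp (applying the repeat penalty along the way), cut the candidate list to the top K entries via the new sample_top_k helper, convert the survivors to probabilities, apply the top-p cumulative cutoff, and draw a token. Below is a self-contained sketch of that pipeline; the repeat-penalty step is omitted for brevity, and the function name sample is hypothetical, not this repository's API:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <random>
#include <utility>
#include <vector>

// Sketch of top-k -> softmax -> top-p -> sample, mirroring the refactored code.
int sample(const std::vector<float> & logits, int top_k, double top_p,
           double temp, std::mt19937 & rng) {
    std::vector<std::pair<double, int>> logits_id;
    logits_id.reserve(logits.size());
    for (int i = 0; i < (int) logits.size(); ++i) {
        logits_id.emplace_back(logits[i] / temp, i);  // scale by 1/temp
    }

    // keep only the top K candidates (what sample_top_k does via partial_sort)
    top_k = std::min(top_k, (int) logits_id.size());
    std::partial_sort(logits_id.begin(), logits_id.begin() + top_k, logits_id.end(),
            [](const std::pair<double, int> & a, const std::pair<double, int> & b) {
        return a.first > b.first;
    });
    logits_id.resize(top_k);

    // max-subtracted softmax over the survivors, for numerical stability
    const double maxl = logits_id.front().first;
    std::vector<double> probs;
    double sum = 0.0;
    for (const auto & kv : logits_id) {
        const double p = std::exp(kv.first - maxl);
        probs.push_back(p);
        sum += p;
    }
    for (auto & p : probs) p /= sum;

    // top-p: keep the shortest prefix whose cumulative probability reaches top_p
    if (top_p < 1.0) {
        double cumsum = 0.0;
        for (size_t i = 0; i < probs.size(); ++i) {
            cumsum += probs[i];
            if (cumsum >= top_p) {
                probs.resize(i + 1);
                break;
            }
        }
        for (auto & p : probs) p /= cumsum;  // renormalize the kept prefix
    }

    std::discrete_distribution<> dist(probs.begin(), probs.end());
    return logits_id[dist(rng)].second;  // token id
}

int main() {
    std::mt19937 rng(0);
    const std::vector<float> logits = {1.0f, 3.0f, 0.5f, 2.5f, -1.0f};
    printf("sampled id: %d\n", sample(logits, /*top_k=*/3, /*top_p=*/0.95, /*temp=*/0.8, rng));
    return 0;
}

One design note: factoring sample_top_k out lets the same top-k cut be reused by other samplers without duplicating the std::partial_sort boilerplate, which is the point of this commit's refactor.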
19 changes: 6 additions & 13 deletions utils.h
@@ -19,7 +19,7 @@ struct gpt_params {
     int32_t repeat_last_n = 64; // last n tokens to penalize
 
     // sampling parameters
-    int32_t top_k = 40; // unused
+    int32_t top_k = 40;
     float top_p = 0.95f;
     float temp = 0.80f;
     float repeat_penalty = 1.30f;
@@ -77,26 +77,19 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
 // - consider only the top K tokens
 // - from them, consider only the top tokens with cumulative probability > P
 //
-// TODO: not sure if this implementation is correct
-// TODO: temperature is not implemented
-//
-gpt_vocab::id gpt_sample_top_k_top_p(
-        const gpt_vocab & vocab,
-        const float * logits,
-        int top_k,
-        double top_p,
-        double temp,
-        std::mt19937 & rng);
-
-gpt_vocab::id llama_sample_top_p(
+gpt_vocab::id llama_sample_top_p_top_k(
         const gpt_vocab & vocab,
         const float * logits,
         std::vector<gpt_vocab::id> & last_n_tokens,
         double repeat_penalty,
+        int top_k,
         double top_p,
         double temp,
         std::mt19937 & rng);
+
+// filter to top K tokens from list of logits
+void sample_top_k(std::vector<std::pair<double, gpt_vocab::id>> & logits_id, int top_k);
 
 //
 // Quantization
 //
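For reference, the last_n_tokens and repeat_penalty parameters that now travel alongside top_k implement a repetition penalty: logits of recently generated tokens are made less attractive before sampling. A simplified sketch of the idea, assuming the CTRL-style convention of dividing positive logits and multiplying negative ones by the penalty; the body of llama_sample_top_p_top_k that applies it is not shown in this diff:

#include <cstdio>
#include <vector>

int main() {
    std::vector<double> logits = {2.0, -1.5, 0.5}; // toy values, one per token id
    const std::vector<int> last_n_tokens = {0, 2}; // token ids seen recently
    const double repeat_penalty = 1.30;            // default from utils.h

    for (int id : last_n_tokens) {
        // push the penalized token toward "less likely" in both sign cases
        if (logits[id] > 0.0) logits[id] /= repeat_penalty;
        else                  logits[id] *= repeat_penalty;
    }

    for (double l : logits) printf("%.3f ", l);    // prints: 1.538 -1.500 0.385
    printf("\n");
    return 0;
}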
