server: fix reported top tokens for temperature 0 #7203

Merged

common/sampling.cpp (6 changes: 3 additions & 3 deletions)

@@ -35,7 +35,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_

     result->prev.resize(params.n_prev);

-    result->n_considered = 0;
+    result->n_valid = 0;

     llama_sampling_set_rng_seed(result, params.seed);

@@ -66,7 +66,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) {

     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
     ctx->cur.clear();
-    ctx->n_considered = 0;
+    ctx->n_valid = 0;
 }

 void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
@@ -256,7 +256,7 @@ static llama_token llama_sampling_sample_impl(
         }
     }

-    ctx_sampling->n_considered = cur_p.size;
+    ctx_sampling->n_valid = temp == 0.0f ? 0 : cur_p.size;

     return id;
 }
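Why temperature 0 needs this: at temp == 0 the sampler takes the greedy path (llama_sample_token_greedy), picking the argmax logit directly instead of running the softmax/top-k/top-p chain, so the probabilities left in cur_p were never recomputed and none of them can be trusted. A minimal sketch of that control flow, using simplified stand-in types rather than the real llama.cpp API:

// Sketch only; token_data and sample_sketch are simplified stand-ins,
// not the llama.cpp API.
#include <algorithm>
#include <cstddef>
#include <vector>

struct token_data { int id; float logit; float p; };

// Returns the chosen token id and sets n_valid the way the patch does.
int sample_sketch(std::vector<token_data> & cur, float temp, size_t & n_valid) {
    if (temp == 0.0f) {
        // Greedy path: take the argmax logit. The chain that sorts the
        // candidates and fills in p never runs, so every p in cur is stale;
        // no reported probability would be correct.
        n_valid = 0;
        return std::max_element(cur.begin(), cur.end(),
            [](const token_data & a, const token_data & b) {
                return a.logit < b.logit;
            })->id;
    }
    // Sampled path (chain elided): the candidates that survive filtering are
    // sorted with correct probabilities, so all of them are valid.
    n_valid = cur.size();
    return cur.front().id;
}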
common/sampling.h (2 changes: 1 addition & 1 deletion)

@@ -81,7 +81,7 @@ struct llama_sampling_context {
     // TODO: replace with ring-buffer
     std::vector<llama_token> prev;
     std::vector<llama_token_data> cur;
-    size_t n_considered;
+    size_t n_valid; // Number of correct top tokens with correct probabilities.

     std::mt19937 rng;
 };
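The field comment is terse; the invariant it encodes is that after a sampling call, entries cur[0] through cur[n_valid - 1] are correctly ranked top tokens with correct probabilities, and n_valid == 0 (the temp == 0 case) means no probability can be trusted. A hypothetical accessor spelling that out, again with simplified types standing in for the real structs:

// Hypothetical helper, not part of llama.cpp: shows how a consumer should
// gate on n_valid before trusting any probability in cur.
#include <cstddef>
#include <vector>

struct token_data { int id; float logit; float p; };

struct sampling_context_sketch {
    std::vector<token_data> cur;
    size_t n_valid; // entries [0, n_valid) of cur carry correct probabilities
};

// Returns the i-th candidate's probability, or 0.0f when the sampler never
// computed it (filtered out by the chain, or greedy sampling, so n_valid == 0).
float reported_prob(const sampling_context_sketch & ctx, size_t i) {
    return i < ctx.n_valid ? ctx.cur[i].p : 0.0f;
}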
examples/server/server.cpp (6 changes: 3 additions & 3 deletions)

@@ -2270,10 +2270,10 @@ struct server_context {

     const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
     if (n_probs > 0) {
-        const size_t n_considered = slot.ctx_sampling->n_considered;
+        const size_t n_valid = slot.ctx_sampling->n_valid;

         // Make sure at least n_probs top tokens are at the front of the vector:
-        if (slot.sparams.temp == 0.0f && n_probs > n_considered) {
+        if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
             llama_sample_top_k(ctx, &cur_p, n_probs, 0);
         }

@@ -2289,7 +2289,7 @@
         for (size_t i = 0; i < n_probs; ++i) {
             result.probs.push_back({
                 cur_p.data[i].id,
-                i >= n_considered ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
+                i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
             });
         }
     }
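A worked example of the reporting rule above, with made-up numbers: at temp > 0 with top_k = 2 the sampler leaves n_valid = 2, so for n_probs = 4 the first two tokens keep their softmax probabilities and the rest report 0.0f; at temp == 0, n_valid is 0, llama_sample_top_k backfills the true top n_probs tokens, and every probability reports as 0.0f. A self-contained rehearsal of the loop:

// Self-contained rehearsal of the server loop above; the data is invented.
#include <cstddef>
#include <cstdio>
#include <vector>

struct token_prob { int id; float p; };

int main() {
    // Suppose top_k = 2 survived the chain (temp > 0): n_valid = 2, but the
    // server was asked for n_probs = 4. p beyond index n_valid is stale.
    std::vector<token_prob> cur_p = { {42, 0.7f}, {7, 0.3f}, {13, 0.1f}, {99, 0.05f} };
    const size_t n_probs = 4;
    const size_t n_valid = 2;

    for (size_t i = 0; i < n_probs; ++i) {
        // Same rule as the diff: filtered-out tokens report 0 probability.
        const float p = i >= n_valid ? 0.0f : cur_p[i].p;
        std::printf("token %3d -> %.2f\n", cur_p[i].id, p);
    }
    // Prints 0.70, 0.30, 0.00, 0.00. At temp == 0, n_valid would be 0 and all
    // four would print 0.00, now over the correct top tokens thanks to the
    // llama_sample_top_k backfill.
    return 0;
}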