Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: fix incorrectly reported token probabilities #7125

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_

result->prev.resize(params.n_prev);

result->n_considered = 0;

llama_sampling_set_rng_seed(result, params.seed);

return result;
Expand Down Expand Up @@ -64,6 +66,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) {

std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
ctx->cur.clear();
ctx->n_considered = 0;
}

void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
Expand Down Expand Up @@ -253,6 +256,8 @@ static llama_token llama_sampling_sample_impl(
}
}

ctx_sampling->n_considered = cur_p.size;

return id;
}

Expand Down
1 change: 1 addition & 0 deletions common/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ struct llama_sampling_context {
// TODO: replace with ring-buffer
std::vector<llama_token> prev;
std::vector<llama_token_data> cur;
size_t n_considered;

std::mt19937 rng;
};
Expand Down
2 changes: 1 addition & 1 deletion examples/server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ node index.js

`logit_bias`: Modify the likelihood of a token appearing in the generated text completion. For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood. Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced. The tokens can also be represented as strings, e.g. `[["Hello, World!",-0.5]]` will reduce the likelihood of all the individual tokens that represent the string `Hello, World!`, just like the `presence_penalty` does. Default: `[]`

`n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token. Default: `0`
`n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token given the sampling settings. Note that for temperature < 0 the tokens are sampled greedily but token probabilities are still being calculated via a simple softmax of the logits without considering any other sampler settings. Default: `0`

`min_keep`: If greater than 0, force samplers to return N possible tokens at minimum. Default: `0`

Expand Down
34 changes: 24 additions & 10 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2266,17 +2266,31 @@ struct server_context {
llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
result.tok = id;

const int32_t n_probs = slot.sparams.n_probs;
if (slot.sparams.temp <= 0 && n_probs > 0) {
// for llama_sample_token_greedy we need to sort candidates
llama_sample_softmax(ctx, &cur_p);
}
const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
if (n_probs > 0) {
const size_t n_considered = slot.ctx_sampling->n_considered;

for (size_t i = 0; i < std::min(cur_p.size, (size_t) n_probs); ++i) {
result.probs.push_back({
cur_p.data[i].id,
cur_p.data[i].p
});
// Make sure at least n_probs top tokens are at the front of the vector:
if (slot.sparams.temp == 0.0f && n_probs > n_considered) {
llama_sample_top_k(ctx, &cur_p, n_probs, 0);
}

if (slot.sparams.temp == 0.0f) {
// With greedy sampling the probabilities have possibly not been calculated.
for (size_t i = 0; i < n_probs; ++i) {
result.probs.push_back({
cur_p.data[i].id,
i == 0 ? 1.0f : 0.0f
});
}
} else {
for (size_t i = 0; i < n_probs; ++i) {
result.probs.push_back({
cur_p.data[i].id,
i >= n_considered ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
});
}
}
}

if (!process_token(result, slot)) {
Expand Down
Loading