Skip to content

Commit

Permalink
whisper : apply logit filters and compute logprobs
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Dec 23, 2022
1 parent 86e0e50 commit 4c5527b
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 3 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ build/
build-em/
build-debug/
build-release/
build-static/
build-sanitize-addr/
build-sanitize-thread/

Expand All @@ -18,6 +19,7 @@ build-sanitize-thread/
/bench

sync.sh
libwhisper.a
libwhisper.so
compile_commands.json

Expand Down
115 changes: 112 additions & 3 deletions whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,9 @@ struct whisper_context {
std::vector<whisper_segment> result_all;
std::vector<whisper_token> prompt_past;

std::vector<float> work_logits; // used to avoid allocations
// used to avoid allocations
std::vector<float> work_logits;
std::vector<float> work_logprobs;

// [EXPERIMENTAL] token-level timestamps data
int64_t t_beg;
Expand Down Expand Up @@ -2614,6 +2616,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str

/*.suppress_blank =*/ true,

/*.max_initial_timestamp =*/ 1.0,

/*.greedy =*/ {
/*.dummy =*/ 0,
},
Expand Down Expand Up @@ -2664,6 +2668,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str

/*.suppress_blank =*/ true,

/*.max_initial_timestamp =*/ 1.0,

/*.greedy =*/ {
/*.dummy =*/ 0,
},
Expand Down Expand Up @@ -2763,17 +2769,120 @@ static struct whisper_token_data whisper_sample_next_token(

// extract the logits for the last token
// we will be mutating and therefore we don't want to use the ctx->logits buffer directly
auto & logits = ctx->work_logits;
auto & logits = ctx->work_logits;
auto & logprobs = ctx->work_logprobs;
{
logits.resize(n_logits);
memcpy(logits.data(), ctx->logits.data() + (ctx->logits.size() - n_logits), n_logits*sizeof(float));

// will be populated a bit later
logprobs.resize(n_logits);
}

// apply logit filters here
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493
// TODO: apply logit filters here
{
// suppress blank
// https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L388-L390
if (params.suppress_blank) {
if (is_initial) {
logits[vocab.token_eot] = -INFINITY;
logits[vocab.token_to_id.at(" ")] = -INFINITY;
}
}

// suppress <|notimestamps|> token
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
logits[vocab.token_not] = -INFINITY;

// timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
// https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
{
const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;

if (last_was_timestamp) {
if (penultimate_was_timestamp) {
for (int i = vocab.token_beg; i < n_logits; ++ i) {
logits[i] = -INFINITY;
}
} else {
for (int i = 0; i < vocab.token_eot; ++ i) {
logits[i] = -INFINITY;
}
}
}
}

// the initial timestamp cannot be larger than max_initial_timestamp
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
if (is_initial && params.max_initial_timestamp > 0.0f) {
const float precision = float(WHISPER_CHUNK_SIZE)/ctx->model.hparams.n_audio_ctx;
const int tid0 = std::round(params.max_initial_timestamp/precision);

for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++ i) {
logits[i] = -INFINITY;
}
}

// populate the logprobs array (log_softmax)
{
const float logit_max = *std::max_element(logits.begin(), logits.end());
float logsumexp = 0.0f;
for (int i = 0; i < n_logits; ++ i) {
logsumexp += expf(logits[i] - logit_max);
}
logsumexp = logf(logsumexp) + logit_max;
for (int i = 0; i < n_logits; ++ i) {
logprobs[i] = logits[i] - logsumexp;
}
}

// if sum of probability over timestamps is above any other token, sample timestamp
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
{
// logsumexp over timestamps
float timestamp_logprob = -INFINITY;
{
float logsumexp = 0.0f;
const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end());
for (int i = vocab.token_beg; i < n_logits; ++ i) {
logsumexp += expf(logprobs[i] - logprob_max);
}
logsumexp = logf(logsumexp) + logprob_max;
timestamp_logprob = logsumexp;
}

const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);

if (timestamp_logprob > max_text_token_logprob) {
for (int i = 0; i < vocab.token_beg; ++ i) {
logits[i] = -INFINITY;
}
}
}
}

// print first 100 logits - token string : logit
for (int i = 0; i < 100; i++) {
const auto token = vocab.id_to_token.at(i);
const auto logit = logits[i];
printf("%s : %f\n", token.c_str(), logit);
}

// "And", "and", " And", " and"
printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);

printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);

switch (params.strategy) {
case WHISPER_SAMPLING_GREEDY:
{
Expand Down
3 changes: 3 additions & 0 deletions whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,11 @@ extern "C" {
// for auto-detection, set to nullptr, "" or "auto"
const char * language;

// common decoding parameters:
bool suppress_blank;

float max_initial_timestamp;

struct {
int dummy;
} greedy;
Expand Down

0 comments on commit 4c5527b

Please sign in to comment.