diff --git a/.gitignore b/.gitignore index 8a495199e75..5ca3702c331 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ build/ build-em/ build-debug/ build-release/ +build-static/ build-sanitize-addr/ build-sanitize-thread/ @@ -18,6 +19,7 @@ build-sanitize-thread/ /bench sync.sh +libwhisper.a libwhisper.so compile_commands.json diff --git a/whisper.cpp b/whisper.cpp index beb39f15cd9..b673db13264 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -427,7 +427,9 @@ struct whisper_context { std::vector result_all; std::vector prompt_past; - std::vector work_logits; // used to avoid allocations + // used to avoid allocations + std::vector work_logits; + std::vector work_logprobs; // [EXPERIMENTAL] token-level timestamps data int64_t t_beg; @@ -2614,6 +2616,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.suppress_blank =*/ true, + /*.max_initial_timestamp =*/ 1.0, + /*.greedy =*/ { /*.dummy =*/ 0, }, @@ -2664,6 +2668,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.suppress_blank =*/ true, + /*.max_initial_timestamp =*/ 1.0, + /*.greedy =*/ { /*.dummy =*/ 0, }, @@ -2763,17 +2769,120 @@ static struct whisper_token_data whisper_sample_next_token( // extract the logits for the last token // we will be mutating and therefore we don't want to use the ctx->logits buffer directly - auto & logits = ctx->work_logits; + auto & logits = ctx->work_logits; + auto & logprobs = ctx->work_logprobs; { logits.resize(n_logits); memcpy(logits.data(), ctx->logits.data() + (ctx->logits.size() - n_logits), n_logits*sizeof(float)); + + // will be populated a bit later + logprobs.resize(n_logits); } + // apply logit filters here // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493 - // TODO: apply logit filters here { + // suppress blank + // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L388-L390 + if (params.suppress_blank) { + if (is_initial) { + logits[vocab.token_eot] = -INFINITY; + logits[vocab.token_to_id.at(" ")] = -INFINITY; + } + } + + // suppress <|notimestamps|> token + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412 + logits[vocab.token_not] = -INFINITY; + + // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly + // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424 + { + const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg; + const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg; + + if (last_was_timestamp) { + if (penultimate_was_timestamp) { + for (int i = vocab.token_beg; i < n_logits; ++ i) { + logits[i] = -INFINITY; + } + } else { + for (int i = 0; i < vocab.token_eot; ++ i) { + logits[i] = -INFINITY; + } + } + } + } + + // the initial timestamp cannot be larger than max_initial_timestamp + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429 + if (is_initial && params.max_initial_timestamp > 0.0f) { + const float precision = float(WHISPER_CHUNK_SIZE)/ctx->model.hparams.n_audio_ctx; + const int tid0 = std::round(params.max_initial_timestamp/precision); + + for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++ i) { + logits[i] = -INFINITY; + } + } + + // populate the logprobs array (log_softmax) + { + const float logit_max = *std::max_element(logits.begin(), logits.end()); + float logsumexp = 0.0f; + for (int i = 0; i < n_logits; ++ i) { + logsumexp += expf(logits[i] - logit_max); + } + logsumexp = logf(logsumexp) + logit_max; + for (int i = 0; i < n_logits; ++ i) { + logprobs[i] = logits[i] - logsumexp; + } + } + + // if sum of probability over timestamps is above any other token, sample timestamp + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437 + { + // logsumexp over timestamps + float timestamp_logprob = -INFINITY; + { + float logsumexp = 0.0f; + const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end()); + for (int i = vocab.token_beg; i < n_logits; ++ i) { + logsumexp += expf(logprobs[i] - logprob_max); + } + logsumexp = logf(logsumexp) + logprob_max; + timestamp_logprob = logsumexp; + } + + const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg); + + if (timestamp_logprob > max_text_token_logprob) { + for (int i = 0; i < vocab.token_beg; ++ i) { + logits[i] = -INFINITY; + } + } + } } + // print first 100 logits - token string : logit + for (int i = 0; i < 100; i++) { + const auto token = vocab.id_to_token.at(i); + const auto logit = logits[i]; + printf("%s : %f\n", token.c_str(), logit); + } + + // "And", "and", " And", " and" + printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]); + printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]); + printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]); + printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]); + printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]); + + printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]); + printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]); + printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]); + printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]); + printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]); + switch (params.strategy) { case WHISPER_SAMPLING_GREEDY: { diff --git a/whisper.h b/whisper.h index 6b163fda829..f4d30baee77 100644 --- a/whisper.h +++ b/whisper.h @@ -261,8 +261,11 @@ extern "C" { // for auto-detection, set to nullptr, "" or "auto" const char * language; + // common decoding parameters: bool suppress_blank; + float max_initial_timestamp; + struct { int dummy; } greedy;