Commit

whisper : prepare infra for new decoding strategies

ggerganov committed Dec 18, 2022
1 parent 1d716d6 commit f06b991

Showing 2 changed files with 76 additions and 26 deletions.
97 changes: 73 additions & 24 deletions whisper.cpp

@@ -16,6 +16,14 @@
#include <vector>
#include <regex>

#define WHISPER_ASSERT(x) \
do { \
if (!(x)) { \
fprintf(stderr, "WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
abort(); \
} \
} while (0)

#define USE_FLASH_ATTN
//#define USE_FLASH_FF

@@ -417,8 +425,9 @@ struct whisper_context {
std::vector<float> logits;

std::vector<whisper_segment> result_all;
std::vector<whisper_token> prompt_past;

std::vector<whisper_token> prompt_past;
std::vector<float> work_logits; // used to avoid allocations

// [EXPERIMENTAL] token-level timestamps data
int64_t t_beg;
@@ -1864,12 +1873,12 @@ static whisper_token_data whisper_sample_best(
0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
};

int n_logits = vocab.id_to_token.size();
const int n_probs = vocab.id_to_token.size();

std::vector<std::pair<double, whisper_vocab::id>> probs_id;
probs_id.reserve(n_logits);
probs_id.reserve(n_probs);

for (int i = 0; i < n_logits; i++) {
for (int i = 0; i < n_probs; i++) {
probs_id.push_back(std::make_pair(probs[i], i));
}

@@ -1883,12 +1892,12 @@
}

const auto i0 = is_initial ? vocab.token_beg + 101 : vocab.token_beg;
const auto i1 = is_initial ? vocab.token_beg + 101 : n_logits;
const auto i1 = is_initial ? vocab.token_beg + 101 : n_probs;

// the initial timestamp cannot be larger than 100
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
if (is_initial) {
for (int i = i0; i < n_logits; ++ i) {
for (int i = i0; i < n_probs; ++ i) {
probs_id[i].first = -INFINITY;
}
}
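
Note: timestamp tokens follow token_beg in 0.02 s steps, so the cap at token_beg + 100 limits the first predicted timestamp to 2.00 s (100 × 0.02 s) into the window.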
@@ -2608,12 +2617,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str

/*.language =*/ "en",

/*.suppress_blank =*/ true,

/*.greedy =*/ {
/*.n_past =*/ 0,
/*.dummy =*/ 0,
},

/*.beam_search =*/ {
/*.n_past =*/ -1,
/*.beam_width =*/ -1,
/*.n_best =*/ -1,
},
@@ -2657,12 +2667,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str

/*.language =*/ "en",

/*.suppress_blank =*/ true,

/*.greedy =*/ {
/*.n_past =*/ -1,
/*.dummy =*/ 0,
},

/*.beam_search =*/ {
/*.n_past =*/ 0,
/*.beam_width =*/ 10,
/*.n_best =*/ 5,
},
@@ -2741,6 +2752,50 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
return res;
}

static struct whisper_token_data whisper_sample_next_token(
struct whisper_context * ctx,
struct whisper_full_params params,
const std::vector<whisper_token> & prompt,
const std::vector<whisper_token_data> & tokens_cur) {
struct whisper_token_data result = {};

const auto & vocab = ctx->vocab;

const bool is_initial = tokens_cur.size() == 0;
const int n_logits = vocab.id_to_token.size();

WHISPER_ASSERT(n_logits == ctx->vocab.n_vocab);

// extract the logits for the last token
// we will be mutating them, so we don't want to use the ctx->logits buffer directly
auto & logits = ctx->work_logits;
{
logits.resize(n_logits);
memcpy(logits.data(), ctx->logits.data() + (ctx->logits.size() - n_logits), n_logits*sizeof(float));
}

// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493
// TODO: apply logit filters here
{
}

switch (params.strategy) {
case WHISPER_SAMPLING_GREEDY:
{
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L249-L274
// TODO: implement
result = (is_initial) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
} break;
case WHISPER_SAMPLING_BEAM_SEARCH:
{
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L277C13-L364
// TODO: implement
} break;
}

return result;
}
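
The empty logit-filter block above is where the new suppress_blank parameter would eventually take effect. A minimal sketch of such a filter, mirroring SuppressBlank from the referenced openai/whisper code and assuming the vocab's token_to_id map and token_eot field (not part of this commit):

    // at the start of a segment, forbid sampling " " and the end-of-text token,
    // so that decoding cannot begin with a blank transcription
    if (params.suppress_blank && is_initial) {
        logits[vocab.token_to_id.at(" ")] = -INFINITY;
        logits[vocab.token_eot]           = -INFINITY;
    }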

int whisper_full(
struct whisper_context * ctx,
struct whisper_full_params params,
@@ -2870,7 +2925,6 @@ int whisper_full(
return -4;
}

int n_past = 0;
prompt.clear();

// if we have already generated some text, use it as a prompt to condition the next generation
@@ -2886,20 +2940,21 @@

prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());

int seek_delta = 100*WHISPER_CHUNK_SIZE;

// print the prompt
//printf("\n\n");
//for (int i = 0; i < prompt.size(); i++) {
// printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str());
//}
//printf("\n\n");

int n_past = 0;
int seek_delta = 100*WHISPER_CHUNK_SIZE;

// the accumulated transcription in the current iteration
int result_len = 0;
tokens_cur.clear();

bool failed = false;
bool failed = false; // has the current segment failed to decode?
bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?

for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
@@ -2911,15 +2966,10 @@
n_past += prompt.size();
prompt.clear();

// very basic greedy sampling strategy:
//
// - always take the most probable token
//
// more sophisticated sampling strategies could be implemented here, but we keep it simple
// feel free to experiment!
//
// sample the next token based on the selected decoding strategy + parameters
// also, update the sliding window position based on the sampled timestamp tokens
{
const auto token = (i == 0) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
const auto token = whisper_sample_next_token(ctx, params, prompt, tokens_cur);

// timestamp token - update sliding window
if (token.id > whisper_token_beg(ctx)) {
@@ -2974,8 +3024,7 @@
}

// sometimes, the decoding can get stuck in a repetition loop
// this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance
// the sliding window by 1 second
// this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
failed = true;
break;
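
A sketch of one possible fallback after this flag is set, matching the 1-second skip described by the comment this commit rewrites (the actual handling sits outside the shown hunk, so this is assumed rather than confirmed):

    if (failed) {
        // advance the sliding window by 1 second (time units of 10 ms) and retry
        seek += 100;
        continue;
    }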
5 changes: 3 additions & 2 deletions whisper.h
@@ -261,12 +261,13 @@ extern "C" {
// for auto-detection, set to nullptr, "" or "auto"
const char * language;

bool suppress_blank;

struct {
int n_past;
int dummy;
} greedy;

struct {
int n_past;
int beam_width;
int n_best;
} beam_search;
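
Putting the two files together, a caller picks a strategy via the factory and can tune the new fields before running the transcription. A minimal usage sketch (assuming a loaded ctx and a pcmf32 vector of float samples):

    // greedy decoding with the new blank suppression enabled
    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    wparams.suppress_blank = true;

    // once implemented, beam search would be selected the same way,
    // with defaults beam_width = 10 and n_best = 5:
    //   whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);

    if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
        fprintf(stderr, "failed to process audio\n");
    }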