Make sampling not throw exception
jart committed May 22, 2024
1 parent aa3094c commit 6c6d55b
Showing 9 changed files with 35 additions and 3 deletions.
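The commit replaces the two `throw std::runtime_error(...)` paths in `llama_sampling_prepare_impl` with sentinel return values, and teaches every caller to check for them: `llama_sampling_prepare` now returns a `llama_token_data_array` whose `data` pointer is NULL when `llama_get_logits_ith` fails, and `llama_sampling_sample` propagates that as a token id of `-1`. A minimal caller-side sketch of the new contract (the context setup and the `idx` batch index are hypothetical placeholders; only the two checks mirror the diffs below):

    // Preparing candidates: failure yields an array with data == NULL.
    llama_token_data_array cur_p = llama_sampling_prepare(
        ctx_sampling, ctx, NULL, idx, /* apply_grammar= */ false, NULL);
    if (cur_p.data == NULL) {
        return 1; // no logits were available for this batch index
    }

    // Sampling a token: failure yields -1, which is not a valid token id.
    const llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, idx);
    if (id == -1) {
        return 1;
    }
    llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);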
common/sampling.cpp (5 additions, 2 deletions)
@@ -189,6 +189,9 @@ static llama_token llama_sampling_sample_impl(
 
     std::vector<float> original_logits;
     auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
+    if (cur_p.data == NULL) {
+        return -1;
+    }
     if (ctx_sampling->grammar != NULL && !is_resampling) {
         GGML_ASSERT(!original_logits.empty());
     }
@@ -286,7 +289,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
     if (!logits) {
-        throw std::runtime_error("llama_get_logits_ith failed");
+        return {NULL, 0, false};
     }
 
     if (ctx_sampling->grammar != NULL && !apply_grammar) {
@@ -303,7 +306,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
     if (ctx_cfg) {
         float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
         if (!logits_guidance) {
-            throw std::runtime_error("llama_get_logits_ith failed");
+            return {NULL, 0, false};
         }
         llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
     }
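For reference, the `{NULL, 0, false}` sentinel simply zero-values the three fields of `llama_token_data_array` as declared in llama.h, so callers can detect failure with a single pointer comparison:

    typedef struct llama_token_data_array {
        llama_token_data * data;   // candidate tokens; NULL signals failure here
        size_t             size;   // number of candidates; 0 in the sentinel
        bool               sorted; // whether the candidates are sorted
    } llama_token_data_array;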
examples/infill/infill.cpp (3 additions, 0 deletions)
@@ -530,6 +530,9 @@ int main(int argc, char ** argv) {
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
 
             const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
+            if (id == -1) {
+                return 1;
+            }
 
             llama_sampling_accept(ctx_sampling, ctx, id, true);
 
examples/llava/llava-cli.cpp (1 addition, 0 deletions)
@@ -44,6 +44,7 @@ static const char * sample(struct llama_sampling_context * ctx_sampling,
                            struct llama_context * ctx_llama,
                            int * n_past) {
     const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
+    GGML_ASSERT(id != -1);
     llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
     static std::string ret;
     if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
examples/lookahead/lookahead.cpp (9 additions, 0 deletions)
@@ -159,6 +159,9 @@ int main(int argc, char ** argv) {
     // sample first token
     {
         id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
+        if (id == -1) {
+            return 1;
+        }
 
         llama_sampling_accept(ctx_sampling, ctx, id, true);
 
@@ -284,6 +287,9 @@ int main(int argc, char ** argv) {
 
             // sample the next token
             id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
+            if (id == -1) {
+                return 1;
+            }
 
             llama_sampling_accept(ctx_sampling, ctx, id, true);
 
@@ -361,6 +367,9 @@ int main(int argc, char ** argv) {
                 // sample from the last level
                 for (int i = 0; i < W; i++) {
                     tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
+                    if (tokens_j[N - 2][i] == -1) {
+                        return 1;
+                    }
                 }
             } else {
                 for (int i = 0; i < W; i++) {
examples/lookup/lookup.cpp (1 addition, 0 deletions)
@@ -131,6 +131,7 @@ int main(int argc, char ** argv){
     while (true) {
         // sample from the target model
         llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
+        GGML_ASSERT(id != -1);
 
         llama_sampling_accept(ctx_sampling, ctx, id, true);
 
examples/main/main.cpp (3 additions, 0 deletions)
@@ -706,6 +706,9 @@ int main(int argc, char ** argv) {
             }
 
             const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
+            if (id == -1) {
+                return 1;
+            }
 
             llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);
 
examples/parallel/parallel.cpp (1 addition, 0 deletions)
@@ -341,6 +341,7 @@ int main(int argc, char ** argv) {
             // client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
 
             const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
+            GGML_ASSERT(id != -1);
 
             llama_sampling_accept(client.ctx_sampling, ctx, id, true);
 
examples/server/server.cpp (3 additions, 0 deletions)
@@ -2257,6 +2257,9 @@ struct server_context {
 
             completion_token_output result;
             const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
+            if (id == -1) {
+                continue; // keep going, don't crash, already logged
+            }
 
             llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
 
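Note that the recovery strategy varies with the program: the one-shot examples bail out of main with a nonzero exit code, llava-cli, lookup, and parallel treat a failed sample as a hard assertion failure via `GGML_ASSERT`, and the server alone `continue`s past the failed slot, presumably so that one bad batch index cannot bring down a long-running multi-client process.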
examples/speculative/speculative.cpp (9 additions, 1 deletion)
@@ -229,6 +229,9 @@ int main(int argc, char ** argv) {
                 // stochastic verification
 
                 llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
+                if (dist_tgt.data == NULL) {
+                    return 1;
+                }
                 llama_sample_softmax(ctx_tgt, &dist_tgt);
                 float p_tgt = 0, p_dft = 0;
 
@@ -337,6 +340,9 @@ int main(int argc, char ** argv) {
                 // sample from the target model
                 LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
                 token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
+                if (token_id == -1) {
+                    return 1;
+                }
 
                 llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
 
@@ -457,7 +463,9 @@ int main(int argc, char ** argv) {
                 continue;
            }
 
-            llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
+            if (llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft) == -1) {
+                return -1;
+            }
 
             const auto & cur_p = drafts[s].ctx_sampling->cur;
 
