diff --git a/common/sampling.cpp b/common/sampling.cpp
index da529514512fd7..97a8d2b04c8ffa 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -189,6 +189,9 @@ static llama_token llama_sampling_sample_impl(

     std::vector<float> original_logits;
     auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
+    if (cur_p.data == NULL) {
+        return -1;
+    }
     if (ctx_sampling->grammar != NULL && !is_resampling) {
         GGML_ASSERT(!original_logits.empty());
     }
@@ -286,7 +289,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
     if (!logits) {
-        throw std::runtime_error("llama_get_logits_ith failed");
+        return {NULL, 0, false};
     }

     if (ctx_sampling->grammar != NULL && !apply_grammar) {
@@ -303,7 +306,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
     if (ctx_cfg) {
         float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
         if (!logits_guidance) {
-            throw std::runtime_error("llama_get_logits_ith failed");
+            return {NULL, 0, false};
         }
         llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
     }
diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index afac145f63934c..f1a86346224f29 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -530,6 +530,9 @@ int main(int argc, char ** argv) {
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {

             const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
+            if (id == -1) {
+                return 1;
+            }

             llama_sampling_accept(ctx_sampling, ctx, id, true);

diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp
index a6d67e5d72cd28..e293edda018e61 100644
--- a/examples/llava/llava-cli.cpp
+++ b/examples/llava/llava-cli.cpp
@@ -44,6 +44,7 @@ static const char * sample(struct llama_sampling_context * ctx_sampling,
                            struct llama_context * ctx_llama,
                            int * n_past) {
     const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
+    GGML_ASSERT(id != -1);
     llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
     static std::string ret;
     if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp
index 9c3540b2008c20..dcf3c21751743c 100644
--- a/examples/lookahead/lookahead.cpp
+++ b/examples/lookahead/lookahead.cpp
@@ -159,6 +159,9 @@ int main(int argc, char ** argv) {

     // sample first token
     {
         id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
+        if (id == -1) {
+            return 1;
+        }

         llama_sampling_accept(ctx_sampling, ctx, id, true);
@@ -284,6 +287,9 @@ int main(int argc, char ** argv) {

             // sample the next token
             id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
+            if (id == -1) {
+                return 1;
+            }

             llama_sampling_accept(ctx_sampling, ctx, id, true);

@@ -361,6 +367,9 @@ int main(int argc, char ** argv) {
                 // sample from the last level
                 for (int i = 0; i < W; i++) {
                     tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
+                    if (tokens_j[N - 2][i] == -1) {
+                        return 1;
+                    }
                 }
             } else {
                 for (int i = 0; i < W; i++) {
diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp
index eebbd00a58e66c..01e02f182649df 100644
--- a/examples/lookup/lookup.cpp
+++ b/examples/lookup/lookup.cpp
@@ -131,6 +131,7 @@ int main(int argc, char ** argv){
     while (true) {
         // sample from the target model
         llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
+        GGML_ASSERT(id != -1);

         llama_sampling_accept(ctx_sampling, ctx, id, true);

diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 832b51ee086bec..3c2ba844b86f62 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -706,6 +706,9 @@ int main(int argc, char ** argv) {
             }

             const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
+            if (id == -1) {
+                return 1;
+            }

             llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);

diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index 7c5595d6edb2dc..91fdd73fbeeeb4 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -341,6 +341,7 @@ int main(int argc, char ** argv) {
             //        client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);

             const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
+            GGML_ASSERT(id != -1);

             llama_sampling_accept(client.ctx_sampling, ctx, id, true);

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 6af5cb96e6d131..d2b6bd335d1669 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2257,6 +2257,9 @@ struct server_context {

                 completion_token_output result;
                 const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
+                if (id == -1) {
+                    continue; // keep going, don't crash, already logged
+                }

                 llama_sampling_accept(slot.ctx_sampling, ctx, id, true);

diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 12e46fbc91a242..78932cd0c40ad4 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -229,6 +229,9 @@ int main(int argc, char ** argv) {

                 // stochastic verification

                 llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
+                if (dist_tgt.data == NULL) {
+                    return 1;
+                }
                 llama_sample_softmax(ctx_tgt, &dist_tgt);
                 float p_tgt = 0, p_dft = 0;
@@ -337,6 +340,9 @@ int main(int argc, char ** argv) {
             // sample from the target model
             LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);

             token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
+            if (token_id == -1) {
+                return 1;
+            }

             llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
@@ -457,7 +463,9 @@ int main(int argc, char ** argv) {
                 continue;
             }

-            llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
+            if (llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft) == -1) {
+                return -1;
+            }

             const auto & cur_p = drafts[s].ctx_sampling->cur;