build(dep): Update llama.cpp for new sampling API
Issue #20
grencez committed Apr 30, 2023
1 parent 768cdfc commit 4388992
Showing 2 changed files with 24 additions and 8 deletions.
2 changes: 1 addition & 1 deletion dep/cmake_fetchcontent/llama_cpp.cmake
@@ -1,7 +1,7 @@
 FetchContent_Declare(
   LlamaCpp
   GIT_REPOSITORY "https://github.com/ggerganov/llama.cpp.git"
-  GIT_TAG "54bb60e26858be251a0eb3cb70f80322aff804a0"
+  GIT_TAG "c3ca7a5f0546c561eb278be3f2fe335795679e01"
 )
 FetchContent_MakeAvailable(LlamaCpp)
 set(LlamaCpp_INCLUDE_DIRS "${LlamaCpp_SOURCE_DIR}" PARENT_SCOPE)
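The pinned revision moves llama.cpp past its sampling-API rework, which is what forces the chat.cc changes below: one monolithic sampling call is replaced by composable stages over an explicit candidate array. A declaration-level sketch of that change, reconstructed only from the call sites in this diff (parameter names are mine, and the real header may take extra arguments):

    // Old API: penalize, truncate, and sample in a single call.
    llama_token llama_sample_top_p_top_k(
        llama_context* ctx,
        const llama_token* last_tokens, int last_tokens_count,
        int top_k, float top_p, float temp, float repeat_penalty);

    // New API: each stage mutates a llama_token_data_array in place.
    void llama_sample_repetition_penalty(
        llama_context* ctx, llama_token_data_array* candidates,
        const llama_token* last_tokens, size_t last_tokens_count,
        float penalty);
    void llama_sample_top_k(llama_context* ctx, llama_token_data_array* candidates, int k);
    void llama_sample_top_p(llama_context* ctx, llama_token_data_array* candidates, float p);
    void llama_sample_temperature(llama_context* ctx, llama_token_data_array* candidates, float temp);
    llama_token llama_sample_token(llama_context* ctx, llama_token_data_array* candidates);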
30 changes: 23 additions & 7 deletions src/chat/chat.cc
@@ -115,10 +115,11 @@ rendezllama::generate_next_token(
 {
   // Suppress the end-of-stream token by zeroing its logit.
   // (Technically called "end-of-sentence", but it's treated as end-of-stream.)
-  llama_get_logits(ctx)[llama_token_eos()] = 0;
+  float* logits = llama_get_logits(ctx);
+  logits[llama_token_eos()] = 0;
 
   if (preventing_newline) {
-    llama_get_logits(ctx)[rendezllama::newline_token(ctx)] = 0;
+    logits[rendezllama::newline_token(ctx)] = 0;
   }
 
   const size_t trailing_token_count = std::min(
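An aside on the zeroed logits (my note, not part of the commit): softmax gives a logit of 0 the weight exp(0) = 1, so a zeroed token's probability becomes tiny but not exactly zero; only a logit of -INFINITY bans a token outright. That is presumably why the code further down still handles an end-of-stream token that slips through. A minimal standalone illustration with made-up logit values:

    // Why zeroing a logit suppresses but does not ban a token.
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<float> logits = {8.0f, 6.5f, 0.0f};  // last entry was "zeroed"
      float max_logit = *std::max_element(logits.begin(), logits.end());
      float sum = 0.0f;
      for (float x : logits) sum += std::exp(x - max_logit);
      for (float x : logits) {
        // The zeroed entry keeps weight exp(0 - max_logit) > 0,
        // so it can still be sampled, just very rarely.
        std::printf("p = %.6f\n", std::exp(x - max_logit) / sum);
      }
      return 0;
    }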
@@ -133,10 +134,25 @@ rendezllama::generate_next_token(
       penalized_tokens.end(),
       extra_penalized_tokens.begin(), extra_penalized_tokens.end());
 
-  llama_token token_id = llama_sample_top_p_top_k(
-      ctx,
-      &penalized_tokens[0], penalized_tokens.size(),
-      opt.top_k, opt.top_p, opt.temp, opt.repeat_penalty);
+  std::vector<llama_token_data> candidates;
+  candidates.resize(llama_n_vocab(ctx));
+  for (llama_token i = 0; i < (llama_token)candidates.size(); i++) {
+    candidates[i] = llama_token_data{
+      i, logits[i], 0.0f,
+    };
+  }
+  llama_token_data_array candidates_data = {
+    candidates.data(), candidates.size(), false,
+  };
+
+  llama_sample_repetition_penalty(
+      ctx, &candidates_data,
+      penalized_tokens.data(), penalized_tokens.size(),
+      opt.repeat_penalty);
+  llama_sample_top_k(ctx, &candidates_data, opt.top_k);
+  llama_sample_top_p(ctx, &candidates_data, opt.top_p);
+  llama_sample_temperature(ctx, &candidates_data, opt.temp);
+  llama_token token_id = llama_sample_token(ctx, &candidates_data);
 
   // If the improbable happens, just use a newline token.
   if (token_id == llama_token_eos()) {
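Pulled together, the new flow is: snapshot the logits into a candidate array, apply the repetition penalty, truncate with top-k then top-p, rescale with temperature, then draw a token. A self-contained sketch of that pipeline (mine, not repository code), using only the llama.cpp calls and arities that appear in this diff; the Options struct and the sample_next_token helper are stand-ins for rendezllama's own types:

    #include <vector>
    #include "llama.h"  // header from the fetched llama.cpp dependency

    struct Options {
      int top_k;
      float top_p;
      float temp;
      float repeat_penalty;
    };

    llama_token
    sample_next_token(llama_context* ctx,
                      const std::vector<llama_token>& penalized_tokens,
                      const Options& opt)
    {
      // 1. Snapshot the logits into a candidate list; the `p` fields start
      //    at zero and are filled in by the samplers.
      float* logits = llama_get_logits(ctx);
      std::vector<llama_token_data> candidates(llama_n_vocab(ctx));
      for (llama_token i = 0; i < (llama_token)candidates.size(); i++) {
        candidates[i] = llama_token_data{i, logits[i], 0.0f};
      }
      llama_token_data_array candidates_data = {
        candidates.data(), candidates.size(), false,
      };

      // 2. Apply the stages in order: penalty, truncation, temperature,
      //    and finally the draw itself.
      llama_sample_repetition_penalty(
          ctx, &candidates_data,
          penalized_tokens.data(), penalized_tokens.size(),
          opt.repeat_penalty);
      llama_sample_top_k(ctx, &candidates_data, opt.top_k);
      llama_sample_top_p(ctx, &candidates_data, opt.top_p);
      llama_sample_temperature(ctx, &candidates_data, opt.temp);
      return llama_sample_token(ctx, &candidates_data);
    }

Compared with the old llama_sample_top_p_top_k, which fused all of these stages into one call, the staged API lets a caller reorder stages or skip the ones it does not want.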
@@ -246,7 +262,7 @@ rendezllama::maybe_insert_answer_prompt(
   }
   if (answer_prompt_offset > 0) {
     chat_tokens.insert(
-        chat_tokens.begin()+answer_prompt_offset,
+        chat_tokens.begin()+answer_prompt_offset,
         answer_prompt_tokens.begin(),
         answer_prompt_tokens.end());
   }
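The hunk above appears to touch only whitespace, but the splice itself is worth a gloss: std::vector::insert with an iterator range places the answer-prompt tokens at answer_prompt_offset and shifts the rest of the history right. A toy demonstration with made-up token values (mine, not repository code):

    #include <cassert>
    #include <vector>

    using llama_token = int;  // stand-in for llama.cpp's token type

    int main() {
      std::vector<llama_token> chat_tokens = {10, 11, 12, 40, 41};
      const std::vector<llama_token> answer_prompt_tokens = {20, 21};
      const size_t answer_prompt_offset = 3;  // insert before token 40

      if (answer_prompt_offset > 0) {
        chat_tokens.insert(
            chat_tokens.begin() + answer_prompt_offset,
            answer_prompt_tokens.begin(),
            answer_prompt_tokens.end());
      }
      assert((chat_tokens == std::vector<llama_token>{10, 11, 12, 20, 21, 40, 41}));
      return 0;
    }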
