feat(option): to set more sampling parameters
These new sampling parameters are disabled by default.

Issue #20
grencez committed May 1, 2023
1 parent 4388992 commit ac1c1f9
Showing 4 changed files with 172 additions and 213 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -47,12 +47,16 @@ Remember, the recent chat content is just a rolling prompt concatenated to the e
 - Repeat penalty.
   - `/repeat_penalty 1.2` sets the repeated token penalty.
   - `/repeat_window 20` penalizes the most recent 20 tokens from being generated.
+  - `/frequency_penalty 0.1` sets the frequency penalty. (0.0 is default, off)
+  - `/presence_penalty 0.1` sets the presence penalty. (0.0 is default, off)
   - `/less= some unwanted words` adds extra tokens to be penalized.
   - `/dropless` clears the extra penalized tokens list.
 - Generation parameters.
   - `/temp 0.7` sets the temperature.
   - `/top_k 40` sets the `top_k` parameter.
   - `/top_p 0.9` sets the `top_p` parameter.
+  - `/tfs_z 0.9` sets the Tail Free Sampling cutoff. (1.0 is default, off)
+  - `/typical_p 0.9` sets the Locally Typical Sampling cutoff. (1.0 is default, off)
 - Execution parameters.
   - `/thread_count 8` sets the number of threads.
   - `/batch_count 8` sets the batch size.
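
Note that each new sampler is neutral at its default, so generation behavior is unchanged until a user opts in. As a rough sketch, the option fields backing these commands might look like the following; the field names mirror the opt.* uses in chat.cc below, but this is an illustrative guess, not the repository's actual ChatOptions definition, and the values marked "assumed" are not from the source:

    struct ChatOptions {
      // Neutral values leave the token distribution untouched.
      float repeat_penalty = 1.0f;     // assumed default; 1.0 is off
      float frequency_penalty = 0.0f;  // 0.0 is off (per README)
      float presence_penalty = 0.0f;   // 0.0 is off (per README)
      float temp = 0.8f;               // assumed default
      int top_k = 40;                  // assumed default
      float top_p = 0.9f;              // assumed default
      float tfs_z = 1.0f;              // 1.0 is off (per README)
      float typical_p = 1.0f;          // 1.0 is off (per README)
    };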
32 changes: 25 additions & 7 deletions src/chat/chat.cc
@@ -105,6 +105,22 @@ rendezllama::augment_chat_input(
   }
 }
 
+static
+llama_token
+temperature_based_sample(
+    llama_token_data_array* candidates_data,
+    struct llama_context* ctx,
+    const rendezllama::ChatOptions& opt)
+{
+  const unsigned keep_one = 1;
+  llama_sample_top_k(ctx, candidates_data, opt.top_k, keep_one);
+  llama_sample_tail_free(ctx, candidates_data, opt.tfs_z, keep_one);
+  llama_sample_typical(ctx, candidates_data, opt.typical_p, keep_one);
+  llama_sample_top_p(ctx, candidates_data, opt.top_p, keep_one);
+  llama_sample_temperature(ctx, candidates_data, opt.temp);
+  return llama_sample_token(ctx, candidates_data);
+}
+
 llama_token
 rendezllama::generate_next_token(
     struct llama_context* ctx,
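
The helper added above chains the cutoff samplers in a fixed order (top-k, tail free, locally typical, top-p), then temperature-scales the survivors and draws one token. Each stage passes the candidate array through unchanged at its neutral value, so with tfs_z and typical_p left at 1.0 the chain reduces to the previous top-k/top-p/temperature behavior. A minimal usage sketch, assuming ctx, opt, logits, and n_vocab are in scope as in generate_next_token; the signatures match the llama.cpp API at the time of this commit and have changed in later versions:

    // Build the candidate array from raw logits, then sample.
    std::vector<llama_token_data> candidates(n_vocab);
    for (llama_token i = 0; i < n_vocab; ++i) {
      candidates[i] = llama_token_data{i, logits[i], 0.0f};
    }
    llama_token_data_array candidates_data[1] = {{
      candidates.data(), candidates.size(), false,
    }};
    // With neutral options, tail free and typical sampling are no-ops here.
    llama_token token_id = temperature_based_sample(candidates_data, ctx, opt);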
@@ -141,18 +157,20 @@ rendezllama::generate_next_token(
       i, logits[i], 0.0f,
     };
   }
-  llama_token_data_array candidates_data = {
+  llama_token_data_array candidates_data[1] = {{
     candidates.data(), candidates.size(), false,
-  };
+  }};
 
   llama_sample_repetition_penalty(
-      ctx, &candidates_data,
+      ctx, candidates_data,
       penalized_tokens.data(), penalized_tokens.size(),
       opt.repeat_penalty);
-  llama_sample_top_k(ctx, &candidates_data, opt.top_k);
-  llama_sample_top_p(ctx, &candidates_data, opt.top_p);
-  llama_sample_temperature(ctx, &candidates_data, opt.temp);
-  llama_token token_id= llama_sample_token(ctx, &candidates_data);
+  llama_sample_frequency_and_presence_penalties(
+      ctx, candidates_data,
+      penalized_tokens.data(), penalized_tokens.size(),
+      opt.frequency_penalty, opt.presence_penalty);
+
+  llama_token token_id = temperature_based_sample(candidates_data, ctx, opt);
 
   // If the improbable happens, just use a newline token.
   if (token_id == llama_token_eos()) {
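
The other behavioral change is the frequency/presence pass that now runs after the repetition penalty. Conceptually it applies the OpenAI-style rule: each candidate's logit drops by frequency_penalty for every occurrence of that token in the penalized window, plus presence_penalty once if it occurred at all. The following is a sketch of that rule, not the llama.cpp implementation of llama_sample_frequency_and_presence_penalties, and occurrences() is a hypothetical counting helper:

    // Conceptual effect of the frequency/presence penalty pass.
    for (size_t j = 0; j < candidates_data->size; ++j) {
      llama_token_data& cand = candidates_data->data[j];
      const int count = occurrences(cand.id, penalized_tokens);  // hypothetical helper
      if (count > 0) {
        cand.logit -= count * opt.frequency_penalty + opt.presence_penalty;
      }
    }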