diff --git a/README.md b/README.md index c4cb690..a821208 100644 --- a/README.md +++ b/README.md @@ -47,12 +47,16 @@ Remember, the recent chat content is just a rolling prompt concatenated to the e - Repeat penalty. - `/repeat_penalty 1.2` sets the repeated token penalty. - `/repeat_window 20` penalizes the most recent 20 tokens from being generated. + - `/frequency_penalty 0.1` sets the frequency penalty. (0.0 is default, off) + - `/presence_penalty 0.1` sets presence penalty. (0.0 is default, off) - `/less= some unwanted words` adds extra tokens to be penalized. - `/dropless` clears the extra penalized tokens list. - Generation parameters. - `/temp 0.7` sets the temperature. - `/top_k 40` sets the `top_k` parameter. - `/top_p 0.9` sets the `top_p` parameter. + - `/tfs_z 0.9` sets Tail Free Sampling cutoff. (1.0 is default, off) + - `/typical_p 0.9` sets the Locally Typical Sampling cutoff. (1.0 is default, off) - Execution parameters. - `/thread_count 8` sets the number of threads. - `/batch_count 8` sets the batch size. diff --git a/src/chat/chat.cc b/src/chat/chat.cc index dab1f99..21cc1ee 100644 --- a/src/chat/chat.cc +++ b/src/chat/chat.cc @@ -105,6 +105,22 @@ rendezllama::augment_chat_input( } } +static + llama_token +temperature_based_sample( + llama_token_data_array* candidates_data, + struct llama_context* ctx, + const rendezllama::ChatOptions& opt) +{ + const unsigned keep_one = 1; + llama_sample_top_k(ctx, candidates_data, opt.top_k, keep_one); + llama_sample_tail_free(ctx, candidates_data, opt.tfs_z, keep_one); + llama_sample_typical(ctx, candidates_data, opt.typical_p, keep_one); + llama_sample_top_p(ctx, candidates_data, opt.top_p, keep_one); + llama_sample_temperature(ctx, candidates_data, opt.temp); + return llama_sample_token(ctx, candidates_data); +} + llama_token rendezllama::generate_next_token( struct llama_context* ctx, @@ -141,18 +157,20 @@ rendezllama::generate_next_token( i, logits[i], 0.0f, }; } - llama_token_data_array candidates_data = { + llama_token_data_array candidates_data[1] = {{ candidates.data(), candidates.size(), false, - }; + }}; llama_sample_repetition_penalty( - ctx, &candidates_data, + ctx, candidates_data, penalized_tokens.data(), penalized_tokens.size(), opt.repeat_penalty); - llama_sample_top_k(ctx, &candidates_data, opt.top_k); - llama_sample_top_p(ctx, &candidates_data, opt.top_p); - llama_sample_temperature(ctx, &candidates_data, opt.temp); - llama_token token_id= llama_sample_token(ctx, &candidates_data); + llama_sample_frequency_and_presence_penalties( + ctx, candidates_data, + penalized_tokens.data(), penalized_tokens.size(), + opt.frequency_penalty, opt.presence_penalty); + + llama_token token_id = temperature_based_sample(candidates_data, ctx, opt); // If the improbable happens, just use a newline token. if (token_id == llama_token_eos()) { diff --git a/src/chat/opt.cc b/src/chat/opt.cc index 4970d87..7b60a8d 100644 --- a/src/chat/opt.cc +++ b/src/chat/opt.cc @@ -57,6 +57,25 @@ parse_quoted_string(FildeshX* in) return s; } +static + bool +maybe_parse_bool_option( + bool* b, FildeshX* in, + const char* name) +{ + int tmp_b = 0; + if (!skipstr_FildeshX(in, name)) { + return false; + } + if (parse_int_FildeshX(in, &tmp_b)) { + *b = (tmp_b != 0); + } + else { + fildesh_log_warning("Need a 1 or 0."); + } + return true; +} + static bool parse_options_sxproto( @@ -178,6 +197,11 @@ parse_options_sxproto( opt.transcript_sibling_filename = filename; opt.transcript_filename = parse_quoted_string(&slice); } + else if ( + maybe_parse_bool_option(&opt.mlock_on, &slice, "mlock_on") || + maybe_parse_bool_option(&opt.mmap_on, &slice, "mmap_on")) { + // Success! + } else if (rendezllama::maybe_parse_option_command(opt, &slice, nullout)) { // Success! } @@ -353,15 +377,23 @@ rendezllama::parse_options(rendezllama::ChatOptions& opt, int argc, char** argv) opt.command_prefix_char = argv[argi][0]; } else if (0 == strcmp("--thread_count", argv[argi])) { + int n = 0; argi += 1; - if (!fildesh_parse_int(&opt.thread_count, argv[argi]) || opt.thread_count <= 0) { + if (fildesh_parse_int(&n, argv[argi]) && n >= 0) { + opt.thread_count = n; + } + else { fildesh_log_error("--thread_count needs positive arg"); exstatus = 64; } } else if (0 == strcmp("--batch_count", argv[argi])) { + int n = 0; argi += 1; - if (!fildesh_parse_int(&opt.batch_count, argv[argi]) || opt.batch_count <= 0) { + if (fildesh_parse_int(&n, argv[argi]) && n >= 0) { + opt.batch_count = n; + } + else { fildesh_log_error("--batch_count needs positive arg"); exstatus = 64; } @@ -389,79 +421,19 @@ rendezllama::parse_options(rendezllama::ChatOptions& opt, int argc, char** argv) } } else if (0 == strcmp("--context_token_limit", argv[argi])) { + int n = 0; argi += 1; - if (!fildesh_parse_int(&opt.context_token_limit, argv[argi])) { - fildesh_log_error("--context_token_limit needs int"); - exstatus = 64; - } - else if (opt.context_token_limit > 2048) { - fildesh_log_warning("--context_token_limit is above 2048. Expect poor results."); - } - } - else if (0 == strcmp("--sentence_limit", argv[argi])) { - argi += 1; - if (!fildesh_parse_int(&opt.sentence_limit, argv[argi])) { - fildesh_log_error("--sentence_limit needs int"); - exstatus = 64; - } - } - else if (0 == strcmp("--sentence_token_limit", argv[argi])) { - argi += 1; - if (!fildesh_parse_int(&opt.sentence_token_limit, argv[argi])) { - fildesh_log_error("--sentence_token_limit needs int"); - exstatus = 64; - } - } - // Original stuff. - else if (0 == strcmp("--repeat_last_n", argv[argi]) || - 0 == strcmp("--repeat_window", argv[argi])) - { - argi += 1; - if (!fildesh_parse_int(&opt.repeat_last_count, argv[argi])) { - fildesh_log_error("--repeat_window needs int"); - exstatus = 64; - } - } - else if (0 == strcmp("--repeat_penalty", argv[argi])) { - argi += 1; - double f = 0; - if (!fildesh_parse_double(&f, argv[argi])) { - fildesh_log_error("--repeat_penalty needs float"); - exstatus = 64; - } - else { - opt.repeat_penalty = f; - } - } - else if (0 == strcmp("--temp", argv[argi])) { - argi += 1; - double f = 0; - if (!fildesh_parse_double(&f, argv[argi])) { - fildesh_log_error("--temp needs float"); - exstatus = 64; + if (fildesh_parse_int(&n, argv[argi]) && n > 0) { + opt.context_token_limit = n; + if (opt.context_token_limit > 2048) { + fildesh_log_warning("--context_token_limit is above 2048. Expect poor results."); + } } else { - opt.temp = f; - } - } - else if (0 == strcmp("--top_k", argv[argi])) { - argi += 1; - if (!fildesh_parse_int(&opt.top_k, argv[argi])) { - fildesh_log_error("--top_k needs int"); + fildesh_log_error("--context_token_limit needs positive arg"); exstatus = 64; } } - else if (0 == strcmp("--top_p", argv[argi])) { - argi += 1; - double f = 0; - if (!fildesh_parse_double(&f, argv[argi])) { - fildesh_log_error("--top_p needs float"); - exstatus = 64; - } - else { - opt.top_p = f; - } - } else { exstatus = 64; } @@ -496,150 +468,111 @@ rendezllama::parse_options(rendezllama::ChatOptions& opt, int argc, char** argv) return exstatus; } +static bool -rendezllama::maybe_parse_option_command( - rendezllama::ChatOptions& opt, - FildeshX* in, - std::ostream& eout) +maybe_parse_float( + float* f, FildeshX* in, std::ostream& out, + const char* name, + const char* command_delim_chars) { - if (skipstr_FildeshX(in, "repeat_last_n") || - skipstr_FildeshX(in, "repeat_last_count") || - skipstr_FildeshX(in, "repeat_window")) - { - int n = -1; - if (!skipchrs_FildeshX(in, opt.command_delim_chars)) { - eout << "repeat_window=" << opt.repeat_last_count << '\n'; eout.flush(); - } - else if (parse_int_FildeshX(in, &n) && n >= 0) { - opt.repeat_last_count = n; - } - else { - fildesh_log_warning("Need an int."); - } + double tmp_f = 0; + if (!skipstr_FildeshX(in, name)) { + return false; } - else if (skipstr_FildeshX(in, "repeat_penalty")) { - double f = 0; - if (!skipchrs_FildeshX(in, opt.command_delim_chars)) { - eout << "repeat_penalty=" << opt.repeat_penalty << '\n'; eout.flush(); - } - else if (parse_double_FildeshX(in, &f) && f >= 0) { - opt.repeat_penalty = f; - } - else { - fildesh_log_warning("Need a float."); - } + if (!skipchrs_FildeshX(in, command_delim_chars)) { + out << name << "=" << *f << '\n'; out.flush(); } - else if (skipstr_FildeshX(in, "temp")) { - double f = 0; - if (!skipchrs_FildeshX(in, opt.command_delim_chars)) { - eout << "temp=" << opt.temp << '\n'; eout.flush(); - } - else if (parse_double_FildeshX(in, &f) && f >= 0) { - opt.temp = f; - } - else { - fildesh_log_warning("Need a float."); - } + else if (parse_double_FildeshX(in, &tmp_f) && tmp_f >= 0) { + *f = (float) tmp_f; } - else if (skipstr_FildeshX(in, "top_k")) { - int n = -1; - if (!skipchrs_FildeshX(in, opt.command_delim_chars)) { - eout << "top_k=" << opt.top_k << '\n'; eout.flush(); - } - else if (parse_int_FildeshX(in, &n) && n > 0) { - opt.top_k = n; - } - else { - fildesh_log_warning("Need an int."); - } + else { + fildesh_log_warning("Need a float."); } - else if (skipstr_FildeshX(in, "top_p")) { - double f = 0; - if (!skipchrs_FildeshX(in, opt.command_delim_chars)) { - eout << "top_p=" << opt.top_p << '\n'; eout.flush(); - } - else if (parse_double_FildeshX(in, &f) && f >= 0) { - opt.top_p = f; - } - else { - fildesh_log_warning("Need a float."); - } + return true; +} + +static + bool +maybe_parse_nat( + unsigned* n, FildeshX* in, std::ostream& out, + const char* name, + const char* command_delim_chars) +{ + int tmp_n = 0; + if (!skipstr_FildeshX(in, name)) { + return false; } - else if (skipstr_FildeshX(in, "thread_count")) { - int n = -1; - if (!skipchrs_FildeshX(in, opt.command_delim_chars)) { - eout << "thread_count=" << opt.thread_count << '\n'; eout.flush(); - } - else if (parse_int_FildeshX(in, &n) && n > 0) { - opt.thread_count = n; - } - else { - fildesh_log_warning("Need a positive int."); - } + if (!skipchrs_FildeshX(in, command_delim_chars)) { + out << name << "=" << *n << '\n'; out.flush(); } - else if (skipstr_FildeshX(in, "batch_count")) { - int n = -1; - if (!skipchrs_FildeshX(in, opt.command_delim_chars)) { - eout << "batch_count=" << opt.batch_count << '\n'; eout.flush(); - } - else if (parse_int_FildeshX(in, &n) && n > 0) { - opt.batch_count = n; - } - else { - fildesh_log_warning("Need a positive int."); - } + else if (parse_int_FildeshX(in, &tmp_n) && tmp_n >= 0) { + *n = (unsigned) tmp_n; } - else if (skipstr_FildeshX(in, "mlock_on")) { - int n = 0; - if (!skipchrs_FildeshX(in, opt.command_delim_chars)) { - eout << "mlock_on=" << (opt.mlock_on ? 1 : 0) << '\n'; eout.flush(); - } - else if (parse_int_FildeshX(in, &n)) { - opt.mlock_on = (n != 0); - } - else { - fildesh_log_warning("Need a 1 or 0."); - } + else { + fildesh_log_warning("Need a non-negative int."); } - else if (skipstr_FildeshX(in, "mmap_on")) { - int n = 0; - if (!skipchrs_FildeshX(in, opt.command_delim_chars)) { - eout << "mmap_on=" << (opt.mmap_on ? 1 : 0) << '\n'; eout.flush(); - } - else if (parse_int_FildeshX(in, &n)) { - opt.mmap_on = (n != 0); - } - else { - fildesh_log_warning("Need a 1 or 0."); - } + return true; +} + +static + bool +maybe_parse_positive( + unsigned* n, FildeshX* in, std::ostream& out, + const char* name, + const char* command_delim_chars) +{ + int tmp_n = 0; + if (!skipstr_FildeshX(in, name)) { + return false; } - else if (skipstr_FildeshX(in, "sentence_limit")) { - int n = -1; - if (!skipchrs_FildeshX(in, opt.command_delim_chars)) { - eout << "sentence_limit=" << opt.sentence_limit << '\n'; eout.flush(); - } - else if (parse_int_FildeshX(in, &n) && n >= 0) { - opt.sentence_limit = n; - } - else { - fildesh_log_warning("Need a non-negative int."); - } + if (!skipchrs_FildeshX(in, command_delim_chars)) { + out << name << "=" << *n << '\n'; out.flush(); } - else if (skipstr_FildeshX(in, "sentence_token_limit")) - { - int n = -1; - if (!skipchrs_FildeshX(in, opt.command_delim_chars)) { - eout << "sentence_token_limit=" << opt.sentence_token_limit << '\n'; eout.flush(); - } - else if (parse_int_FildeshX(in, &n) && n >= 0) { - opt.sentence_token_limit = n; - } - else { - fildesh_log_warning("Need a non-negative int."); - } + else if (parse_int_FildeshX(in, &tmp_n) && tmp_n > 0) { + *n = (unsigned) tmp_n; + } + else { + fildesh_log_warning("Need a positive int."); + } + return true; +} + + bool +rendezllama::maybe_parse_option_command( + rendezllama::ChatOptions& opt, + FildeshX* in, + std::ostream& out) +{ + const char* const delims = opt.command_delim_chars; + if ( + maybe_parse_float(&opt.frequency_penalty, in, out, "frequency_penalty", delims) || + maybe_parse_float(&opt.presence_penalty, in, out, "presence_penalty", delims) || + maybe_parse_float(&opt.repeat_penalty, in, out, "repeat_penalty", delims) || + maybe_parse_nat(&opt.repeat_last_count, in, out, "repeat_window", delims) || + maybe_parse_nat(&opt.repeat_last_count, in, out, "repeat_last_n", delims) || + maybe_parse_nat(&opt.repeat_last_count, in, out, "repeat_last_count", delims)) { + // Success! + } + else if ( + maybe_parse_positive(&opt.top_k, in, out, "top_k", delims) || + maybe_parse_float(&opt.top_p, in, out, "top_p", delims) || + maybe_parse_float(&opt.tfs_z, in, out, "tfs_z", delims) || + maybe_parse_float(&opt.typical_p, in, out, "typical_p", delims) || + maybe_parse_float(&opt.temp, in, out, "temp", delims)) { + // Success! + } + else if ( + maybe_parse_positive(&opt.thread_count, in, out, "thread_count", delims) || + maybe_parse_positive(&opt.batch_count, in, out, "batch_count", delims)) { + // Success! + } + else if ( + maybe_parse_nat(&opt.sentence_limit, in, out, "sentence_limit", delims) || + maybe_parse_nat(&opt.sentence_token_limit, in, out, "sentence_token_limit", delims)) { + // Success! } else if (skipstr_FildeshX(in, "opt")) { - print_options(eout, opt); + print_options(out, opt); } else { return false; diff --git a/src/chat/opt.hh b/src/chat/opt.hh index 194517f..fbd1495 100644 --- a/src/chat/opt.hh +++ b/src/chat/opt.hh @@ -33,16 +33,20 @@ struct ChatOptions { char command_prefix_char = '/'; const char command_delim_chars[5] = ":=! "; - int thread_count = 1; - int sentence_limit = 3; - int sentence_token_limit = 50; - int top_k = 1000; + unsigned thread_count = 1; + unsigned sentence_limit = 3; + unsigned sentence_token_limit = 50; + unsigned top_k = 1000; float top_p = 0.95; float temp = 0.7; + float tfs_z = 1.0; + float typical_p = 1.0; + float frequency_penalty = 0.0; + float presence_penalty = 0.0; float repeat_penalty = 1.2; - int repeat_last_count = 20; - int context_token_limit = 2048; - int batch_count = 8; + unsigned repeat_last_count = 20; + unsigned context_token_limit = 2048; + unsigned batch_count = 8; int seed; bool mlock_on = false; bool mmap_on = true;