
sampling : add XTC sampler #9742

Merged · 49 commits · Oct 15, 2024
89640b0
Initial XTC commit
MaggotHATE Oct 4, 2024
9455194
Cleanup
MaggotHATE Oct 4, 2024
db54ac5
Simplified chances calculation
MaggotHATE Oct 4, 2024
41e1665
First fixes by comments
MaggotHATE Oct 4, 2024
d9c9203
Merge branch 'ggerganov:master' into master
MaggotHATE Oct 4, 2024
f2a2a61
Fixed trailing backspaces
MaggotHATE Oct 4, 2024
4f8e55b
Fixed RNG to be reproducible
MaggotHATE Oct 4, 2024
6d94ba2
Fixed forgotten header
MaggotHATE Oct 4, 2024
49cd211
Moved `min_keep`
MaggotHATE Oct 4, 2024
899e073
Merge branch 'ggerganov:master' into master
MaggotHATE Oct 4, 2024
74f657c
Fixed broken randomization
MaggotHATE Oct 4, 2024
59e8e63
Merge branch 'ggerganov:master' into master
MaggotHATE Oct 5, 2024
63e60de
Swapped sorting for a custom algorithm
MaggotHATE Oct 5, 2024
094caea
Merge branch 'ggerganov:master' into master
MaggotHATE Oct 6, 2024
39940e5
Algorithm rework
MaggotHATE Oct 6, 2024
4c44e3d
Merge branch 'ggerganov:master' into master
MaggotHATE Oct 7, 2024
dbe9ef7
Added XTC to `test-sampling`
MaggotHATE Oct 7, 2024
98b204c
Merge branch 'ggerganov:master' into master
MaggotHATE Oct 7, 2024
8110f78
Merge branch 'ggerganov:master' into master
MaggotHATE Oct 8, 2024
81a0c26
Simplified algorithm and more tests
MaggotHATE Oct 8, 2024
09bc6d5
Updated info in common and args
MaggotHATE Oct 8, 2024
c19fb26
Merged back lost commits in common and arg
MaggotHATE Oct 8, 2024
6feb6b3
Update dump info in common
MaggotHATE Oct 8, 2024
d0b1053
Fixed incorrect min_keep check
MaggotHATE Oct 8, 2024
ed535bb
Merge branch 'ggerganov:master' into master
MaggotHATE Oct 9, 2024
37e02e3
Added XTC to README
MaggotHATE Oct 9, 2024
ba29d31
Merge branch 'ggerganov:master' into master
MaggotHATE Oct 10, 2024
2107882
Renamed parameters, fixed info and defaults
MaggotHATE Oct 10, 2024
f7a383f
Initial server support
MaggotHATE Oct 10, 2024
72db625
Added XTC to server UIs
MaggotHATE Oct 10, 2024
882a603
Merge branch 'master' into master
MaggotHATE Oct 11, 2024
3968369
Fixed labels in old server UI
MaggotHATE Oct 11, 2024
acada1a
Made algorithm safer and more readable
MaggotHATE Oct 11, 2024
dfe587a
Merge branch 'ggerganov:master' into master
MaggotHATE Oct 11, 2024
9c43a01
Removed xtc_threshold_max
MaggotHATE Oct 12, 2024
68557eb
Merge branch 'master' of https://github.com/MaggotHATE/llama.cpp-xtc
MaggotHATE Oct 12, 2024
ea85a51
Merge branch 'ggerganov:master' into master
MaggotHATE Oct 12, 2024
cca842f
Fixed arg after update
MaggotHATE Oct 12, 2024
ea62e65
Merge branch 'ggerganov:master' into master
MaggotHATE Oct 13, 2024
44bbd63
Quick fixes by comments
MaggotHATE Oct 14, 2024
a3e6522
Merge branch 'master' of https://github.com/MaggotHATE/llama.cpp-xtc
MaggotHATE Oct 14, 2024
dfef2c4
Merge branch 'ggerganov:master' into master
MaggotHATE Oct 14, 2024
436a991
Simplified algorithm since threshold_max is removed
MaggotHATE Oct 14, 2024
3613a6d
Renamed random distribution
MaggotHATE Oct 14, 2024
17ad143
Merge branch 'ggerganov:master' into master
MaggotHATE Oct 14, 2024
2be814a
Fixed tests and outdated README
MaggotHATE Oct 15, 2024
28d2cff
Merge branch 'master' of https://github.com/MaggotHATE/llama.cpp-xtc
MaggotHATE Oct 15, 2024
3496f58
Small fixes
MaggotHATE Oct 15, 2024
050eb7a
Merge branch 'ggerganov:master' into master
MaggotHATE Oct 15, 2024
14 changes: 14 additions & 0 deletions common/arg.cpp
@@ -947,6 +947,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sparams.tfs_z = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--xtc-probability"}, "N",
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sparams.xtc_probability),
[](common_params & params, const std::string & value) {
params.sparams.xtc_probability = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--xtc-threshold"}, "N",
string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sparams.xtc_threshold),
[](common_params & params, const std::string & value) {
params.sparams.xtc_threshold = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--typical"}, "N",
string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sparams.typ_p),
2 changes: 2 additions & 0 deletions common/common.cpp
@@ -2104,6 +2104,8 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons
fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
fprintf(stream, "xtc_probability: %f # default: 0.0\n", sparams.xtc_probability);
fprintf(stream, "xtc_threshold: %f # default: 0.1\n", sparams.xtc_threshold);
fprintf(stream, "typ_p: %f # default: 1.0\n", sparams.typ_p);
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
6 changes: 6 additions & 0 deletions common/common.h
@@ -90,6 +90,8 @@ enum common_sampler_type {
COMMON_SAMPLER_TYPE_TFS_Z = 4,
COMMON_SAMPLER_TYPE_TYPICAL_P = 5,
COMMON_SAMPLER_TYPE_TEMPERATURE = 6,
COMMON_SAMPLER_TYPE_XTC = 7,

};

// dimensionality reduction methods, used by cvector-generator
@@ -108,6 +110,8 @@ struct common_sampler_params {
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float xtc_probability = 0.00f; // 0.0 = disabled
float xtc_threshold = 0.10f; // > 0.5 disables XTC
float tfs_z = 1.00f; // 1.0 = disabled
float typ_p = 1.00f; // typical_p, 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
@@ -124,12 +128,14 @@ struct common_sampler_params {
bool ignore_eos = false;
bool no_perf = false; // disable performance metrics


std::vector<enum common_sampler_type> samplers = {
COMMON_SAMPLER_TYPE_TOP_K,
COMMON_SAMPLER_TYPE_TFS_Z,
COMMON_SAMPLER_TYPE_TYPICAL_P,
COMMON_SAMPLER_TYPE_TOP_P,
COMMON_SAMPLER_TYPE_MIN_P,
COMMON_SAMPLER_TYPE_XTC,
COMMON_SAMPLER_TYPE_TEMPERATURE
};

13 changes: 10 additions & 3 deletions common/sampling.cpp
@@ -130,10 +130,10 @@ std::string common_sampler_params::print() const {

snprintf(result, sizeof(result),
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
top_k, tfs_z, top_p, min_p, typ_p, temp,
top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
mirostat, mirostat_eta, mirostat_tau);

return std::string(result);
@@ -184,6 +184,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
case COMMON_SAMPLER_TYPE_MIN_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_XTC:
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
break;
case COMMON_SAMPLER_TYPE_TFS_Z:
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
break;
@@ -372,6 +375,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
case COMMON_SAMPLER_TYPE_XTC: return 'x';
default : return '?';
}
}
@@ -384,6 +388,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
default : return "";
}
}
@@ -396,6 +401,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
{ "tfs_z", COMMON_SAMPLER_TYPE_TFS_Z },
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
};

// since samplers names are written multiple ways
@@ -441,7 +447,8 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE }
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }
};

std::vector<common_sampler_type> samplers;
13 changes: 13 additions & 0 deletions examples/main/README.md
@@ -241,6 +241,19 @@ The `--mirostat-ent` option sets the Mirostat target entropy (tau), which repres

Example usage: `--mirostat 2 --mirostat-lr 0.05 --mirostat-ent 3.0`

### XTC Sampling

- `--xtc-probability N`: Sets the chance for token removal (checked once on sampler start) (default: 0.0).
- `--xtc-threshold N`: Sets a minimum probability threshold for tokens to be removed (default: 0.1).

Exclude Top Choices (XTC) is a unique sampler designed to remove top tokens from consideration and avoid more obvious and repetitive outputs. With a chance of `--xtc-probability` it searches for tokens with probabilities of `--xtc-threshold` and above, then removes all such tokens except the least probable one.

By removing top tokens, XTC can improve the variety of answers, break writing clichés and inhibit repetition, since clichés and repeated phrases are usually more likely to appear. By keeping the last token above the threshold, XTC ensures that the answer is still coherent. XTC is meant for creative tasks, but feel free to experiment with different settings for different models.
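
To make the rule concrete, here is a minimal self-contained sketch of the core step on a toy distribution (illustrative values only; the real implementation operates on `llama_token_data_array`, see `src/llama-sampling.cpp` below):

```cpp
#include <cstdio>
#include <random>
#include <vector>

int main() {
    // toy token probabilities, sorted in descending order
    std::vector<float> probs = {0.40f, 0.30f, 0.15f, 0.10f, 0.05f};

    const float xtc_probability = 0.5f; // chance that XTC triggers at all
    const float xtc_threshold   = 0.1f; // tokens with p >= threshold are "top choices"

    std::mt19937 rng(42);
    std::uniform_real_distribution<float> dist(0.0f, 1.0f);

    if (dist(rng) <= xtc_probability) {
        // find the last (least probable) token still at/above the threshold
        size_t last = 0;
        for (size_t i = 0; i < probs.size(); ++i) {
            if (probs[i] >= xtc_threshold) last = i; else break;
        }
        // remove every top choice except that last one:
        // {0.40, 0.30, 0.15, 0.10, 0.05} -> {0.10, 0.05}
        if (last > 0) probs.erase(probs.begin(), probs.begin() + last);
    }
    for (float p : probs) printf("%.2f\n", p);
}
```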

Being experimental and unique, XTC is disabled by default. The recommended combination of samplers is Min-P followed by XTC on its default settings: `--sampling-seq mx --min-p 0.02 --xtc-probability 0.5`.
Review comment:
Is it still necessary to pass an explicit sampler chain via --sampling-seq in order to activate XTC? I thought that it is now in the sampler chain by default, and disabled by having xtc_probability set to 0.

Contributor Author:
@p-e-w While XTC is included in the sampler queue by default, it is placed after all other truncating samplers. As such, the recommended combination of samplers, as per your words in oobabooga/text-generation-webui#6335, requires passing the sampler chain explicitly.

Review comment:
I don't get it. The order seems correct by default (XTC after truncation, which Min-P is). And all samplers are set to neutral (off) by default, right? So what does --sampling-seq mx do that wouldn't happen otherwise?

Contributor Author:
@p-e-w Not all samplers - Top K is 40 by default and Top P is 0.95. So either they need to be set to 0 and 1.0 respectively, or (which is easier and more logical) the sampling queue should be limited to only the samplers we need.

Review comment:
😮 I had no idea! So llama.cpp users are getting crap samplers from the Stone Age without even realizing it. That's terrible. I would have expected a clean slate that samples from the raw model distribution unless parameters are set explicitly.

Anyway, this means your example command is correct, of course.

Contributor Author:
If I remember correctly, this topic has already been brought up some time ago in other PRs. However, since in most cases llama.cpp is used as a library through another app or as a server, this issue mostly affects llama-cli users. You can look at index-new.html (the new server UI): it has different "default" values, with top_k and top_p turned off, and I assume any other frontend will have a payload with all parameters set as needed.

But yes, this is an issue.

Contributor @strawberrymelonpanda (Oct 15, 2024):
Llama-cli user here via scripts... thanks for bringing this up at any rate. Improved my (custom) benchmark scores just by adjusting settings to disable Top K and Top P.

I've seen the default params adjusted a few times for llama-cli; feels like this would be a good change.
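
For programmatic users, the same effect as `--sampling-seq mx` can be had by restricting the default sampler list before initializing the sampler. A sketch against the structs in this PR's diff (function name and values are illustrative, not code from the PR):

```cpp
#include "common.h" // common_params / common_sampler_params from llama.cpp's common library

// Mirror `--sampling-seq mx --min-p 0.02 --xtc-probability 0.5`: limit the
// queue to Min-P followed by XTC, so the non-neutral Top-K (40) and Top-P
// (0.95) defaults never run.
void use_min_p_then_xtc(common_params & params) {
    params.sparams.samplers = {
        COMMON_SAMPLER_TYPE_MIN_P,
        COMMON_SAMPLER_TYPE_XTC,
    };
    params.sparams.min_p           = 0.02f;
    params.sparams.xtc_probability = 0.50f;
    params.sparams.xtc_threshold   = 0.10f; // default
}
```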


Example usage: `--xtc-probability 0.5 --xtc-threshold 0.1`

### Logit Bias

- `-l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS`: Modify the likelihood of a token appearing in the generated text completion.
6 changes: 6 additions & 0 deletions examples/server/public/index-new.html
@@ -43,6 +43,8 @@
top_k: 0, // <= 0 to use vocab size
top_p: 1.0, // 1.0 = disabled
min_p: 0.05, // 0 = disabled; recommended for non-english: ~ 0.4
xtc_probability: 0.0, // 0 = disabled;
xtc_threshold: 0.1, // > 0.5 disables XTC;
tfs_z: 1.0, // 1.0 = disabled
typical_p: 1.0, // 1.0 = disabled
presence_penalty: 0.0, // 0.0 = disabled
@@ -836,6 +838,8 @@
${FloatField({ label: "TFS-Z", title: "Activates tail-free sampling, a method used to limit the prediction of tokens that are too frequent. The parameter z controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })}
${FloatField({ label: "Frequency Penalty", title: "A penalty that is applied based on the frequency with which certain tokens occur in the training data set. A higher value results in rare tokens being favoured.", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
${FloatField({ label: "Typical-P", title: "Activates local typical sampling, a method used to limit the prediction of tokens that are atypical in the current context. The parameter p controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
${FloatField({ label: "XTC probability", title: "Sets the chance for token removal (checked once on sampler start)", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })}
${FloatField({ label: "XTC threshold", title: "Sets a minimum probability threshold for tokens to be removed", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })}
${IntField({ label: "Min Keep", title: "If greater than 0, samplers are forced to return N possible tokens at minimum. Default is 0", max: 10, min: 0, name: "min_keep", value: params.value.min_keep })}
</fieldset>

@@ -1132,6 +1136,8 @@ <h2>llama.cpp</h2>
const snapSettings = {
temperature: { snapValue: 1.0, snapRangeMultiplier: 6 },
min_p: { snapValue: 0.05, snapRangeMultiplier: 2 },
xtc_probability: { snapValue: 0.0, snapRangeMultiplier: 4 },
xtc_threshold: { snapValue: 0.5, snapRangeMultiplier: 4 },
top_p: { snapValue: 1.0, snapRangeMultiplier: 4 },
tfs_z: { snapValue: 1.0, snapRangeMultiplier: 4 },
typical_p: { snapValue: 1.0, snapRangeMultiplier: 4 },
4 changes: 4 additions & 0 deletions examples/server/public/index.html
@@ -307,6 +307,8 @@
top_k: 40, // <= 0 to use vocab size
top_p: 0.95, // 1.0 = disabled
min_p: 0.05, // 0 = disabled
xtc_probability: 0.0, // 0 = disabled;
xtc_threshold: 0.1, // > 0.5 disables XTC;
tfs_z: 1.0, // 1.0 = disabled
typical_p: 1.0, // 1.0 = disabled
presence_penalty: 0.0, // 0.0 = disabled
@@ -1013,6 +1015,8 @@
${FloatField({ label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
${FloatField({ label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}
${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
${FloatField({ label: "XTC probability", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })}
${FloatField({ label: "XTC threshold", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })}
</fieldset>
<hr />
<fieldset class="three">
4 changes: 4 additions & 0 deletions examples/server/server.cpp
@@ -863,6 +863,8 @@ struct server_context {
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability);
slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold);
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
@@ -1196,6 +1198,8 @@ struct server_context {
{"top_k", slot.sparams.top_k},
{"top_p", slot.sparams.top_p},
{"min_p", slot.sparams.min_p},
{"xtc_probability", slot.sparams.xtc_probability},
{"xtc_threshold", slot.sparams.xtc_threshold},
{"tfs_z", slot.sparams.tfs_z},
{"typical_p", slot.sparams.typ_p},
{"repeat_last_n", slot.sparams.penalty_last_n},
3 changes: 3 additions & 0 deletions include/llama.h
@@ -1101,6 +1101,9 @@ extern "C" {
/// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent);

/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);

/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
95 changes: 95 additions & 0 deletions src/llama-sampling.cpp
@@ -1059,6 +1059,101 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
};
}

// xtc

struct llama_sampler_xtc {
const float probability;
const float threshold;
const size_t min_keep;

const uint32_t seed;
uint32_t seed_cur;

std::mt19937 rng;
};

static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
return "xtc";
}

static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_xtc *) smpl->ctx;

if (ctx->probability <= 0.0f
|| ctx->threshold > 0.5f
|| cur_p->size < 2) {
return;
}

std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
float chance = distribution(ctx->rng);
if (chance > ctx->probability) return;

// in case it's not sorted/recalculated yet
llama_sampler_softmax_impl(cur_p);

int pos_last = 0;

for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].p - ctx->threshold >= -1e-5) {
Review comment:
Why this epsilon instead of a regular comparison?

Contributor Author:
I added it after running tests - they were failing due to a precision problem.

Review comment:
Ugh, no. Never change correct code to satisfy tests. If the code is semantically correct and the tests don't pass, the tests need to be adapted to account for things such as floating-point shenanigans. But changing the code itself is always wrong, unless of course the code has a bug.

The correct way to express this condition is

if (cur_p->data[i].p >= ctx->threshold) {

Nothing else will do. And the tests need to work with that, or the tests are wrong.

Contributor Author:
I'm not sure it should be considered a problem with the tests (precision is a wider topic), but alright.

Review comment:
My basic point is that the tests serve the code, not the other way round. Code is only ever changed in response to failing tests if the code is found to have a bug.

I had a similar problem with tests for an experimental sampler I wrote for llama.cpp a while ago, and I was able to work around it by using a different set of token probabilities in the tests, constructed specifically so that after probability renormalization the resulting values were exactly representable in floating point.
pos_last = i;
} else break;
}

if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
cur_p->data += pos_last;
Member review comment:
This may potentially break 3rd party code that expects this pointer to be unchanged (e.g. to free it after sampling). I don't think this is necessarily a problem, but we should make it clear that this pointer may be changed by the samplers, and applications should not rely on it being unchanged.

cur_p->size = cur_p->size - pos_last;
}
}
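
As a concrete version of the test-side fix suggested in the epsilon thread above, a sketch with hypothetical test data (not the PR's actual tests): pick probabilities that are exact binary fractions, so softmax renormalization reproduces them bit-exactly and a plain `p >= threshold` comparison suffices.

```cpp
#include <cstdio>
#include <vector>

struct tok { int id; float p; };

int main() {
    // 0.5, 0.25 and 0.125 are exact binary fractions summing to 1.0f, so no
    // renormalization error can push a value below the threshold.
    std::vector<tok> cand = { {0, 0.500f}, {1, 0.250f}, {2, 0.125f}, {3, 0.125f} };
    const float threshold = 0.25f; // also exactly representable

    // XTC keeps the least probable token at/above the threshold and removes
    // the ones above it: id 0 is removed, id 1 (p = 0.25) survives.
    size_t last = 0;
    for (size_t i = 0; i < cand.size(); ++i) {
        if (cand[i].p >= threshold) last = i; else break;
    }
    for (size_t i = last; i < cand.size(); ++i) {
        printf("kept id %d (p=%.3f)\n", cand[i].id, cand[i].p);
    }
}
```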

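On the maintainer note above about `cur_p->data` being advanced: a hypothetical caller-side sketch (not from this PR) of why ownership should be tracked separately from the array's `data` pointer.

```cpp
#include <vector>
#include "llama.h"

// Hypothetical third-party caller: after a truncating sampler such as XTC,
// cur_p.data may point into the middle of the original allocation, so the
// vector - not cur_p.data - must remain the owner of the memory.
void apply_chain(llama_sampler * chain, std::vector<llama_token_data> & candidates) {
    llama_token_data_array cur_p = {
        /* .data     = */ candidates.data(),
        /* .size     = */ candidates.size(),
        /* .selected = */ -1,
        /* .sorted   = */ false,
    };

    llama_sampler_apply(chain, &cur_p);

    // cur_p.data may now equal candidates.data() + k for some k > 0;
    // reuse or free the buffer via `candidates`, never via cur_p.data.
}
```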
static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed);

// copy the state
{
auto * result_ctx = (llama_sampler_xtc *) result->ctx;

result_ctx->rng = ctx->rng;
}

return result;
}

static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
delete (llama_sampler_xtc *) smpl->ctx;
}

static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
Review comment:
What is the purpose of this function?

Contributor Author:
AFAIK this is necessary to properly reset the seed and maintain repeatability, as recommended by @slaren earlier.

auto * ctx = (llama_sampler_xtc *) smpl->ctx;
ctx->seed_cur = get_rng_seed(ctx->seed);
ctx->rng.seed(ctx->seed_cur);
}
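
A short sketch of what `reset` buys callers (hedged; `llama_sampler_reset` is the public entry point that dispatches here):

```cpp
#include "llama.h"

// Hypothetical regeneration loop: resetting re-derives seed_cur and reseeds
// the RNG, so a second pass over the same prompt sees the identical sequence
// of XTC coin flips and removes the same tokens at the same positions.
void regenerate(llama_sampler * xtc_or_chain) {
    // ... first generation advances the sampler's internal rng ...
    llama_sampler_reset(xtc_or_chain);
    // ... second generation is now bit-for-bit repeatable ...
}
```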

static struct llama_sampler_i llama_sampler_xtc_i = {
/* .name = */ llama_sampler_xtc_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sample_xtc_apply,
/* .reset = */ llama_sampler_xtc_reset,
/* .clone = */ llama_sampler_xtc_clone,
/* .free = */ llama_sampler_xtc_free,
};

struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
auto seed_cur = get_rng_seed(seed);
return new llama_sampler {
/* .iface = */ &llama_sampler_xtc_i,
/* .ctx = */ new llama_sampler_xtc {
/* .probability = */ p,
/* .threshold = */ t,
/* .min_keep = */ min_keep,
/* .seed = */ seed,
/* .seed_cur = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur),
},
};
}
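
For reference, a minimal sketch of wiring the new sampler into a chain through the public API (parameter values are illustrative; Min-P followed by XTC mirrors the README's recommendation):

```cpp
#include "llama.h"

llama_sampler * make_xtc_chain() {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    // truncation first, then XTC, then the final distribution sampler
    llama_sampler_chain_add(chain, llama_sampler_init_min_p(0.02f, /*min_keep=*/1));
    llama_sampler_chain_add(chain, llama_sampler_init_xtc(/*p=*/0.5f, /*t=*/0.1f, /*min_keep=*/1, /*seed=*/LLAMA_DEFAULT_SEED));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    return chain;
}
```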

// mirostat

struct llama_sampler_mirostat {