Skip to content

Commit

Permalink
DRY: Fixed crash issue due to DRY being in chain but uninitialized
Browse files Browse the repository at this point in the history
  • Loading branch information
wwoodsTM committed Oct 24, 2024
1 parent 1303893 commit dc408bb
Showing 1 changed file with 8 additions and 16 deletions.
24 changes: 8 additions & 16 deletions src/llama-sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1747,7 +1747,7 @@ static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/

static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) {
auto * ctx = (llama_sampler_dry *) smpl->ctx;
if (ctx->dry_penalty_last_n == 0) {
if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
return;
}

Expand All @@ -1758,7 +1758,7 @@ static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token to
static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_dry *) smpl->ctx;

if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f) {
if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
return;
}

Expand Down Expand Up @@ -1999,18 +1999,15 @@ static struct llama_sampler_i llama_sampler_dry_i = {
};

struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
if (dry_multiplier == 0.0f || dry_base < 1.0f) {
return nullptr;
}

int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);

// Process sequence breakers
std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
const int MAX_CHAR_LEN = 40;
const int MAX_SEQ_LEN = 20;

if (seq_breakers != nullptr && num_breakers > 0) {
const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);

if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
// Process sequence breakers
for (size_t i = 0; i < num_breakers; ++i) {
if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) {
LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i);
Expand Down Expand Up @@ -2041,9 +2038,9 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
/* .dry_allowed_length = */ dry_allowed_length,
/* .dry_penalty_last_n = */ dry_penalty_last_n,
/* .dry_processed_breakers = */ std::move(processed_breakers),
/* .dry_repeat_count = */ std::vector<int>(effective_dry_penalty_last_n, 0),
/* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
/* .dry_max_token_repeat = */ {},
/* .last_tokens = */ ring_buffer<llama_token>(effective_dry_penalty_last_n),
/* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
},
};
}
Expand All @@ -2052,11 +2049,6 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
llama_vocab dummy_vocab;
auto * result = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);

if (result == nullptr) {
return nullptr;
}

auto * ctx = (llama_sampler_dry *) result->ctx;

// Process the token-based sequence breakers
Expand Down

0 comments on commit dc408bb

Please sign in to comment.