From a903409a5eb705294822ca55bd397faed88faa88 Mon Sep 17 00:00:00 2001
From: dungquixote42 <62397442+dungquixote42@users.noreply.github.com>
Date: Wed, 4 Mar 2026 07:26:25 -0500
Subject: [PATCH] fix adaptive p sampler rewinding too far back (#1359)

* fix adaptive p sampler rewinding too far back

* update comments

* correct default value for total_weight, more comments

* new variables/names

* update comment for n_rewind

* move null pointer check back to common_sampler_review()

* refactor weighted_sum and total_weight to vector, better boundary check in llama_review_adaptive_p_impl()
---
Review notes (free-form commentary below the three-dash line; not part of
the commit message):

 - llama_review_adaptive_p_impl(): the "critically short history" reset now
   seeds the history with target / (1 - decay) and 1 / (1 - decay), the same
   values llama_init_adaptive_p_impl() pushes. The previous "/ decay" did not
   restore the initial state and divided by zero for decay == 0.
 - common_sampler::n_rewind gets a default member initializer (-1), matching
   the value common_sampler_init() assigns.
 - history.size() is narrowed to int32_t with an explicit static_cast.
 - Template arguments that were lost in transit (std::vector<...>) are
   restored from how the members are used.
 - NOTE(review): the line structure of this patch was rebuilt from a
   whitespace-collapsed copy. Hunk line counts were re-checked against the
   @@ headers, but context-line indentation is reconstructed by convention;
   verify against the tree (or apply with --ignore-whitespace).

 common/sampling.cpp                | 16 +++----
 common/sampling.h                  |  6 +--
 examples/server/server-context.cpp |  6 ++-
 include/llama.h                    |  2 +-
 src/llama-sampling.cpp             | 73 ++++++++++++++++++++++--------
 src/llama-sampling.h               | 11 ++---
 src/llama.cpp                      |  4 +-
 7 files changed, 75 insertions(+), 43 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index 4c4a1371..8d6c9a1a 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -106,6 +106,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
         }
     }
 
+    result->n_rewind = -1;
+
     return result;
 }
 
@@ -143,16 +145,12 @@ void common_sampler_reset(common_sampler * ctx) {
 }
 
 void common_sampler_review(common_sampler * ctx) {
-    if (!ctx->adapt_p_ctx) {
-        return;
+    const int32_t n_rewind = ctx->n_rewind;
+
+    // add stateful samplers here
+    if (ctx->adapt_p_ctx != nullptr) {
+        llama_review_adaptive_p(ctx->adapt_p_ctx, n_rewind);
     }
-    const bool record = ctx->record_samplers;
-    const bool rewind = ctx->rewind_samplers;
-
-    llama_review_adaptive_p(ctx->adapt_p_ctx, record, rewind);
-
-    ctx->record_samplers = false;
-    ctx->rewind_samplers = false;
 }
 
 void llama_sampling_set_rng_seed(struct common_sampler * ctx, uint32_t seed) {
diff --git a/common/sampling.h b/common/sampling.h
index 27bf61db..66d9e613 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -135,8 +135,7 @@ struct common_sampler {
 
     std::mt19937 rng;
 
-    bool record_samplers = false; // record current state for stateful samplers
-    bool rewind_samplers = false; // rewind stateful samplers to last recorded
+    int32_t n_rewind = -1; // number of tokens to rewind (0 = leave state as is, < 0 = nothing to rewind, drop rewind history)
 };
 
 
@@ -152,8 +151,7 @@ void common_sampler_free(struct common_sampler * ctx);
 void common_sampler_reset(common_sampler * ctx);
 
 // Review stateful samplers
-// | record current state for rewinding
-// | rewind to last recorded state
+// - rewind internal states by ctx->n_rewind tokens, or drop the rewind history when it is negative
 void common_sampler_review(common_sampler * ctx);
 
 // Set the sampler seed
diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp
index ea99b068..c46fce38 100644
--- a/examples/server/server-context.cpp
+++ b/examples/server/server-context.cpp
@@ -3332,6 +3332,7 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
     bool next_token = has_next_token(result, slot);
     bool send_result = slot.token_buffer.size() >= slot.n_buffer || !next_token;
     int32_t n_rewind = 0;
+    bool sent_results = false;
     // don't restore if last time was also rewind
     if (!slot.rewind_status) {
         slot.ctx_sampling->params.logit_bias = slot.logit_bias; // restore logit bias
@@ -3343,7 +3344,6 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
     if (n_rewind > 0 && (slot.rewind_count <20 || slot.rewind_count <= 2 * slot.ban_phrases.size())) {
         rewind_context(slot, n_rewind);
         slot.rewind_status = true;
-        slot.ctx_sampling->rewind_samplers = true;
     }
     else if (send_result) {
         slot.rewind_status = false;
@@ -3356,12 +3356,14 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
             // send 1 token
             send_token_results(slot.token_buffer, slot, 1);
         }
-        slot.ctx_sampling->record_samplers = true;
+        sent_results = true;
     }
     else {
         // buffer the result
         slot.sampled = result.tok; // for common batch add
     }
+
+    slot.ctx_sampling->n_rewind = sent_results ? -1 : n_rewind;
 }
 
 void server_context::process_batch_tokens(int32_t & n_batch) {
diff --git a/include/llama.h b/include/llama.h
index 62e0dcdb..fb14d59e 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1415,7 +1415,7 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
             llama_token_data_array * candidates,
             struct llama_sampler_adaptive_p * adapt_p_ctx);
 
-    void llama_review_adaptive_p(struct llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind);
+    void llama_review_adaptive_p(struct llama_sampler_adaptive_p * adapt_p_ctx, const int32_t n_rewind);
 
 
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 9191ba41..c5f808ba 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1053,20 +1053,48 @@ struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab&
 
 // adaptive p
 
-void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind) {
-    if (record && rewind) {
-        LLAMA_LOG_WARN("%s: record AND rewind is invalid\n", __func__);
+void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const int32_t n_rewind) {
+    if ((n_rewind == 0) || (adapt_p_ctx->target < 0.0f)) {
         return;
     }
-    if (record) {
-        adapt_p_ctx->recd_weighted_sum = adapt_p_ctx->weighted_sum;
-        adapt_p_ctx->recd_total_weight = adapt_p_ctx->total_weight;
+    // auto & weighted_sum = adapt_p_ctx->weighted_sum;
+    // auto & total_weight = adapt_p_ctx->total_weight;
+
+    const int32_t sz = static_cast<int32_t>(adapt_p_ctx->history.size());
+    if ((sz <= 0) || (sz <= n_rewind)) {
+        // critically short history. reset to initial state
+        LLAMA_LOG_WARN("%s: sz=%d, n_rewind=%d should not be possible\n", __func__, sz, n_rewind);
+        adapt_p_ctx->history.clear();
+        adapt_p_ctx->history.push_back({
+            adapt_p_ctx->target / (1.0f - adapt_p_ctx->decay), // weighted_sum, same seed as llama_init_adaptive_p_impl()
+            1.0f / (1.0f - adapt_p_ctx->decay) }); // total_weight, same seed as llama_init_adaptive_p_impl()
         return;
     }
-    if (rewind) {
-        adapt_p_ctx->weighted_sum = adapt_p_ctx->recd_weighted_sum;
-        adapt_p_ctx->total_weight = adapt_p_ctx->recd_total_weight;
-        return;
+
+    if (n_rewind < 0) {
+        // clear history except most recent
+        adapt_p_ctx->history.front() = adapt_p_ctx->history.back();
+        adapt_p_ctx->history.resize(1);
+    } else {
+        // rewind
+        adapt_p_ctx->history.resize(sz - n_rewind);
+
+        // int32_t sz = weighted_sum.size() - n_rewind;
+        // if (sz > 0) {
+        //     weighted_sum.resize(sz);
+        // } else {
+        //     LLAMA_LOG_WARN("%s: n_rewind=%d, sz=%d should not be possible\n", __func__, n_rewind, sz);
+        //     weighted_sum.clear();
+        //     weighted_sum.push_back(adapt_p_ctx->target / adapt_p_ctx->decay); // set to default value
+        // }
+        // sz = total_weight.size() - n_rewind;
+        // if (sz > 0) {
+        //     total_weight.resize(sz);
+        // } else {
+        //     LLAMA_LOG_WARN("%s: n_rewind=%d, sz=%d should not be possible\n", __func__, n_rewind, sz);
+        //     total_weight.clear();
+        //     total_weight.push_back(1.0f / adapt_p_ctx->decay); // set to default value
+        // }
     }
 }
 
@@ -1102,8 +1130,11 @@ llama_token llama_sample_token_adaptive_p_impl(
         ? candidates->data[idx].p / ctx->cum_cur_p
         : ctx->orig_prob[id] / ctx->cum_orig_prob;
     if (update_prob > 0) {
-        ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob;
-        ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f;
+        ctx->history.push_back({
+            ctx->decay * ctx->history.back().first + update_prob, // weighted_sum
+            ctx->decay * ctx->history.back().second + 1.0f }); // total_weight
+        // ctx->weighted_sum.push_back(ctx->decay * ctx->weighted_sum.back() + update_prob);
+        // ctx->total_weight.push_back(ctx->decay * ctx->total_weight.back() + 1.0f);
     }
 
     smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
@@ -1138,10 +1169,12 @@ void llama_sample_adaptive_p_impl(struct llama_sampling * ctx, llama_token_data_
     adapt_p_ctx->cum_cur_p = cum_sum;
 
     // compute adapted target probability
+    const float weighted_sum = adapt_p_ctx->history.back().first;
+    const float total_weight = adapt_p_ctx->history.back().second;
     const float target = std::clamp(adapt_p_ctx->target, 0.0f, 1.0f);
-    const float adapted_target = std::clamp(adapt_p_ctx->total_weight == 0.0f
+    const float adapted_target = std::clamp(total_weight == 0.0f
         ? target
-        : 2.0f * target - (adapt_p_ctx->weighted_sum / adapt_p_ctx->total_weight),
+        : 2.0f * target - (weighted_sum / total_weight),
         0.0f, 1.0f);
 
     // transformation constants
@@ -1202,16 +1235,20 @@ struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab,
         /* .decay = */ clamped_decay,
         /* .updt_w_cur = */ updt_w_cur,
         /* .rng = */ std::mt19937(seed),
-        /* .weighted_sum = */ target / (1.0f - clamped_decay),
-        /* .total_weight = */ 1.0f / (1.0f - clamped_decay),
+        // /* .weighted_sum = */ {},
+        // /* .total_weight = */ {},
+        /* .history = */ {},
         /* .orig_prob = */ {},
         /* .cum_orig_prob = */ 1.0f,
         /* .cum_cur_p = */ 1.0f,
         /* .max_xform_logit = */ -INFINITY,
         /* .cum_probs = */ {},
-        /* .recd_weighted_sum = */ target / (1.0f - clamped_decay),
-        /* .recd_total_weight = */ 1.0f / (1.0f - clamped_decay),
     };
+    // result->weighted_sum.push_back(target / (1.0f - clamped_decay));
+    // result->total_weight.push_back(1.0f / (1.0f - clamped_decay));
+    result->history.push_back({
+        target / (1.0f - clamped_decay), // weighted_sum
+        1.0f / (1.0f - clamped_decay) }); // total_weight
     result->orig_prob.resize(n_vocab);
     return result;
 }
diff --git a/src/llama-sampling.h b/src/llama-sampling.h
index 2b52a412..6127a50f 100644
--- a/src/llama-sampling.h
+++ b/src/llama-sampling.h
@@ -70,8 +70,9 @@ struct llama_sampler_adaptive_p {
     const float decay; // EMA decay; history ≈ 1/(1-decay) tokens (0.0 - 0.99)
     const bool updt_w_cur; // false=original, true=current
     std::mt19937 rng; // RNG
-    float weighted_sum; // sum(p_n * decay^N)
-    float total_weight; // sum(decay^i), converges to 1/(1-decay)
+    // std::vector<float> weighted_sum; // [0] = sum(p_n * decay^N)
+    // std::vector<float> total_weight; // [0] = sum(decay^i), converges to 1/(1-decay)
+    std::vector<std::pair<float, float>> history; // {weighted_sum, total_weight} after each update; back() is the current state
 
     // first referenced in prep
     std::vector<float> orig_prob; // for storing the original proibabilities
@@ -83,10 +84,6 @@ struct llama_sampler_adaptive_p {
 
     // first referenced in sample_token
     std::vector<float> cum_probs; // cumulative probability distribution
-
-    // recorded states for rewinding
-    float recd_weighted_sum;
-    float recd_total_weight;
 };
 
 struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab,
@@ -105,7 +102,7 @@ void llama_sample_adaptive_p_impl(
         llama_token_data_array * candidates,
         struct llama_sampler_adaptive_p * adapt_p_ctx);
 
-void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind);
+void llama_review_adaptive_p_impl(llama_sampler_adaptive_p * adapt_p_ctx, const int32_t n_rewind);
 
 
 void llama_sample_repetition_penalties_impl(
diff --git a/src/llama.cpp b/src/llama.cpp
index 22ecc4a4..825800bb 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8304,8 +8304,8 @@ struct llama_sampler_adaptive_p * llama_init_adaptive_p(int n_vocab, const float
     return llama_init_adaptive_p_impl(n_vocab, target, decay, updt_w_cur, seed);
 }
 
-void llama_review_adaptive_p(struct llama_sampler_adaptive_p * adapt_p_ctx, const bool record, const bool rewind) {
-    llama_review_adaptive_p_impl(adapt_p_ctx, record, rewind);
+void llama_review_adaptive_p(struct llama_sampler_adaptive_p * adapt_p_ctx, const int32_t n_rewind) {
+    llama_review_adaptive_p_impl(adapt_p_ctx, n_rewind);
 }
 
 