mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
server: revert checkpoint fix (#1716)
Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
parent
9c7d8b07cc
commit
9f1deefa71
@ -3215,22 +3215,19 @@ void server_context::create_checkpoint_at_interval(server_slot & slot, const gp
|
||||
if (slot.checkpoint_pos + params_base.ctx_checkpoints_interval <= 1 + pos) {
|
||||
bool created = create_checkpoint(slot);
|
||||
if (created) {
|
||||
slot.checkpoint_pos = pos;
|
||||
slot.checkpoint_pos = pos;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void server_context::apply_checkpoint(server_slot & slot) {
|
||||
llama_pos pos_next = slot.cache_tokens.pos_next(slot.n_past);
|
||||
const bool has_recurrent = llama_model_has_recurrent(llama_get_model(slot.ctx));
|
||||
// For hybrid/recurrent models, pos_min semantics don't apply: the recurrent state is a single
|
||||
// snapshot, not a per-token window. Use pos_max against n_past to match whole-prefix checkpoints.
|
||||
const auto pos_min_thold = std::max(0, pos_next - 1);
|
||||
if (slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) {
|
||||
int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
|
||||
|
||||
if (has_recurrent || pos_min > pos_min_thold) {
|
||||
if (pos_min > pos_min_thold) {
|
||||
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int)slot.cache_tokens.size(), slot.id, pos_min);
|
||||
|
||||
// search for a context checkpoint
|
||||
@ -3238,13 +3235,7 @@ void server_context::apply_checkpoint(server_slot & slot) {
|
||||
slot.server_cached_prompt.checkpoints.rbegin(),
|
||||
slot.server_cached_prompt.checkpoints.rend(),
|
||||
[&](const auto & cur) {
|
||||
// guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
|
||||
if (has_recurrent) {
|
||||
// recurrent/hybrid: only whole-prefix checkpoints are valid; pick the latest one
|
||||
// that covers no more than the current n_past and still leaves tokens to decode.
|
||||
return cur.pos_max <= slot.n_past && cur.pos_max < pos_next;
|
||||
}
|
||||
return cur.pos_min < pos_min_thold;
|
||||
return cur.pos_min < pos_min_thold || cur.pos_min == 0;
|
||||
}
|
||||
);
|
||||
|
||||
@ -3270,33 +3261,22 @@ void server_context::apply_checkpoint(server_slot & slot) {
|
||||
}
|
||||
|
||||
if (do_reset) {
|
||||
if (has_recurrent) {
|
||||
// Without a usable recurrent checkpoint, preserving prefix state leaks stale recurrent memory
|
||||
// from prior requests into the current prompt. Force a full prompt re-processing fallback.
|
||||
SLT_WRN(slot, "%s", "no usable hybrid/recurrent checkpoint; forcing full prompt re-processing\n");
|
||||
slot.n_past = 0;
|
||||
slot.n_past_prompt = 0;
|
||||
slot.n_past_se = 0;
|
||||
slot.ga_i = 0;
|
||||
common_sampler_reset(slot.ctx_sampling);
|
||||
} else {
|
||||
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
|
||||
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
|
||||
slot.n_past = 0;
|
||||
slot.n_past_prompt = 0;
|
||||
}
|
||||
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
|
||||
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
|
||||
slot.n_past = 0;
|
||||
slot.n_past_prompt = 0;
|
||||
slot.n_past_se = 0;
|
||||
slot.ga_i = 0;
|
||||
common_sampler_reset(slot.ctx_sampling);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// erase checkpoints that are no longer consistent with the current decode position.
|
||||
// Transformer: anything with pos_min beyond the threshold is stale.
|
||||
// Recurrent/hybrid: anything with pos_max past pos_next refers to future tokens we've rewound past.
|
||||
// erase any checkpoints with pos_min > pos_min_thold
|
||||
for (auto it = slot.server_cached_prompt.checkpoints.begin(); it != slot.server_cached_prompt.checkpoints.end();) {
|
||||
const auto & cur = *it;
|
||||
const bool stale = has_recurrent ? (cur.pos_max > pos_next) : (cur.pos_min > pos_min_thold);
|
||||
if (stale) {
|
||||
if (cur.pos_min > pos_min_thold) {
|
||||
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
|
||||
it = slot.server_cached_prompt.checkpoints.erase(it);
|
||||
} else {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user