From 39b3a188e8f7f9144d12f6ef51c6e660c161c0e7 Mon Sep 17 00:00:00 2001 From: firecoperana <18252262+firecoperana@users.noreply.github.com> Date: Wed, 6 May 2026 00:42:21 -0500 Subject: [PATCH] server: fix mtmd checkpoint restore and avoid checkpoint host copies (#1743) Co-authored-by: firecoperana --- examples/server/server-common.cpp | 8 ++-- examples/server/server-context.cpp | 74 ++++++++++++++++++------------ examples/server/server-task.h | 4 ++ 3 files changed, 53 insertions(+), 33 deletions(-) diff --git a/examples/server/server-common.cpp b/examples/server/server-common.cpp index a0ac7126..d269e6dc 100644 --- a/examples/server/server-common.cpp +++ b/examples/server/server-common.cpp @@ -1647,9 +1647,9 @@ llama_pos server_tokens::pos_next(int64_t n_tokens) const { } -size_t server_tokens::size_up_to_pos(llama_pos max_idx) const { +size_t server_tokens::size_up_to_pos(llama_pos max_pos) const { if (!has_mtmd) { - return std::min((size_t)max_idx+1, tokens.size()); + return std::min((size_t)max_pos, tokens.size()); } size_t idx = 0; @@ -1669,12 +1669,12 @@ size_t server_tokens::size_up_to_pos(llama_pos max_idx) const { idx++; } - if (idx >= max_idx) { + if (pos >= max_pos) { break; } } - return idx+1; + return idx; } diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index ab477ebf..e6df46fc 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -16,6 +16,28 @@ #include #include +static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1, int32_t offset = 0) { + if (pos_min == -1) { + pos_min = llama_kv_cache_seq_pos_min(ctx, id); + } + if (pos_max == -1) { + pos_max = llama_kv_cache_seq_pos_max(ctx, id); + } + const size_t checkpoint_size = llama_state_seq_get_size(ctx, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + ckpt.pos_min = pos_min; + ckpt.pos_max = pos_max; + ckpt.pos_max_prompt = pos_max + offset; + ckpt.pos_min_prompt = pos_min + offset; + ckpt.n_tokens = n_tokens; + ckpt.data.resize(checkpoint_size); + + const size_t n = llama_state_seq_get_data(ctx, ckpt.data.data(), checkpoint_size, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + if (n != checkpoint_size) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n); + } +} + static void log_text(const gpt_params & params_base, const std::string & text) { if (params_base.minilog) { LOG_TEE("%s\n", text.c_str()); @@ -2916,7 +2938,7 @@ void server_context::discard_n_kv_and_cache_tokens(llama_context* ctx, server_sl auto kv_past = slot.cache_tokens.pos_next(slot.n_past); int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id); const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id); - llama_kv_cache_seq_rm(ctx, slot.id, kv_keep, kv_keep + kv_discard); + llama_kv_cache_seq_rm(ctx, slot.id, slot.cache_tokens.pos_next(kv_keep), slot.cache_tokens.pos_next(kv_keep + kv_discard)); llama_kv_cache_seq_add(ctx, slot.id, kv_keep + kv_discard, kv_past, -kv_discard); if (slot.has_mtp && slot.spec) { common_speculative_context_shift(slot.spec, slot.id, kv_keep, kv_discard, kv_past); @@ -3227,7 +3249,7 @@ void server_context::apply_checkpoint(server_slot & slot) { if (slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) { int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id); - if (pos_min > pos_min_thold) { + if (pos_min >= pos_min_thold) { SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int)slot.cache_tokens.size(), slot.id, pos_min); // search for a context checkpoint @@ -3248,15 +3270,17 @@ void server_context::apply_checkpoint(server_slot & slot) { const size_t n = llama_state_seq_set_data(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); if (n != checkpoint_size) { - SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024); + SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float)checkpoint_size / 1024 / 1024); do_reset = true; //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); } else { - slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max)); - slot.n_past = slot.cache_tokens.size_up_to_pos(slot.n_past-1); - slot.n_past_prompt = std::min(slot.n_past_prompt, std::max(it->pos_min_prompt + 1, it->pos_max_prompt)); - slot.n_past_prompt = slot.prompt_tokens.size_up_to_pos(slot.n_past_prompt-1); - SLT_WRN(slot, "restored context checkpoint took %.2f ms (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", (ggml_time_us() - t_start) / 1000.0, it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024); + pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max)); + slot.n_past = slot.cache_tokens.size_up_to_pos(pos_next); + + pos_next = slot.prompt_tokens.pos_next(slot.n_past_prompt); + pos_next = std::min(pos_next, std::max(it->pos_min_prompt + 1, it->pos_max_prompt)); + slot.n_past_prompt = slot.prompt_tokens.size_up_to_pos(pos_next); + SLT_WRN(slot, "restored context checkpoint took %.2f ms (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", (ggml_time_us() - t_start) / 1000.0, it->pos_min, it->pos_max, it->n_tokens, slot.n_past, (float)checkpoint_size / 1024 / 1024); } } @@ -3267,6 +3291,7 @@ void server_context::apply_checkpoint(server_slot & slot) { slot.n_past_prompt = 0; slot.n_past_se = 0; slot.ga_i = 0; + pos_next = 0; common_sampler_reset(slot.ctx_sampling); } } @@ -3276,7 +3301,7 @@ void server_context::apply_checkpoint(server_slot & slot) { // erase any checkpoints with pos_min > pos_min_thold for (auto it = slot.server_cached_prompt.checkpoints.begin(); it != slot.server_cached_prompt.checkpoints.end();) { const auto & cur = *it; - if (cur.pos_min > pos_min_thold) { + if (cur.pos_max > pos_min_thold) { SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024); it = slot.server_cached_prompt.checkpoints.erase(it); } else { @@ -3292,10 +3317,10 @@ bool server_context::create_checkpoint(server_slot & slot) { const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id); // no need for empty or small checkpoints - do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 16); + do_checkpoint = do_checkpoint && (pos_min >= 0 && slot.cache_tokens.n_tokens() >= 64); // no need to create checkpoints that are too close together - do_checkpoint = do_checkpoint && (slot.server_cached_prompt.checkpoints.empty() || pos_max > slot.server_cached_prompt.checkpoints.back().pos_max); + do_checkpoint = do_checkpoint && (slot.server_cached_prompt.checkpoints.empty() || slot.cache_tokens.n_tokens() > slot.server_cached_prompt.checkpoints.back().n_tokens); if (do_checkpoint) { const int64_t t_start = ggml_time_us(); @@ -3303,26 +3328,17 @@ bool server_context::create_checkpoint(server_slot & slot) { // make room for the new checkpoint, if needed const auto & cur = slot.server_cached_prompt.checkpoints.front(); - SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024); + SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", + cur.pos_min, cur.pos_max, cur.n_tokens, (float)cur.data.size() / 1024 / 1024); slot.server_cached_prompt.checkpoints.erase(slot.server_cached_prompt.checkpoints.begin()); } + + auto & cur = slot.server_cached_prompt.checkpoints.emplace_back(); + server_prompt_checkpoint_update(cur, ctx, slot.id, slot.cache_tokens.n_tokens(), pos_min, pos_max, slot.n_past_offset); - const size_t checkpoint_size = llama_state_seq_get_size(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - auto & cur = slot.server_cached_prompt.checkpoints.emplace_back(server_prompt_checkpoint{ - /*.pos_min = */ pos_min, - /*.pos_max = */ pos_max, - /*.pos_min_prompt = */ pos_min + slot.n_past_offset, - /*.pos_max_prompt = */ pos_max + slot.n_past_offset , - /*.data = */ std::vector(checkpoint_size), - }); - - llama_state_seq_get_data(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB, took %.2f ms)\n", - (int)slot.server_cached_prompt.checkpoints.size(), params_base.ctx_checkpoints_n, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024, + SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB, took %.2f ms)\n", + (int)slot.server_cached_prompt.checkpoints.size(), params_base.ctx_checkpoints_n, cur.pos_min, cur.pos_max, cur.n_tokens, (float)cur.data.size() / 1024 / 1024, (ggml_time_us() - t_start) / 1000.0); } return do_checkpoint; @@ -3885,7 +3901,7 @@ void server_context::speculative_decoding_accept() { llama_set_draft_input_hidden_state(mtp_target, slot.mtp_hidden_state.data()); mtp_accept_tokens(mtp_target, ids, mtp_n_past_base, slot.id); } - llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1); + llama_kv_cache_seq_rm(ctx, slot.id, slot.cache_tokens.pos_next(slot.n_past), -1); discard_speculative_checkpoint(slot, ctx); } @@ -4105,7 +4121,7 @@ inline void rewind_context(server_slot& slot, int32_t ban_pos) { slot.n_past = slot.cache_tokens.n_tokens(); // Remove from KV cache - llama_kv_cache_seq_rm(slot.ctx, slot.id, slot.n_past, -1); + llama_kv_cache_seq_rm(slot.ctx, slot.id, slot.cache_tokens.pos_next(slot.n_past), -1); // Truncate buffer slot.token_buffer.resize(n_keep_buffer); diff --git a/examples/server/server-task.h b/examples/server/server-task.h index 0100de20..76a6bad3 100644 --- a/examples/server/server-task.h +++ b/examples/server/server-task.h @@ -353,6 +353,8 @@ struct server_prompt_checkpoint { llama_pos pos_min_prompt; llama_pos pos_max_prompt; + int64_t n_tokens; + std::vector data; size_t size() const { @@ -365,6 +367,7 @@ struct server_prompt_checkpoint { j["pos_max"] = pos_max; j["pos_min_prompt"] = pos_min_prompt; j["pos_max_prompt"] = pos_max_prompt; + j["n_tokens"] = n_tokens; return j; } @@ -373,6 +376,7 @@ struct server_prompt_checkpoint { pos_max = j.value("pos_max", 0); pos_min_prompt = j.value("pos_min_prompt", 0); pos_max_prompt = j.value("pos_max_prompt", 0); + n_tokens = j.value("n_tokens", 0); } };