server: fix mtmd checkpoint restore and avoid checkpoint host copies (#1743)

Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
firecoperana 2026-05-06 00:42:21 -05:00 committed by GitHub
parent e722f0bb73
commit 39b3a188e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 53 additions and 33 deletions

View File

@ -1647,9 +1647,9 @@ llama_pos server_tokens::pos_next(int64_t n_tokens) const {
}
size_t server_tokens::size_up_to_pos(llama_pos max_idx) const {
size_t server_tokens::size_up_to_pos(llama_pos max_pos) const {
if (!has_mtmd) {
return std::min((size_t)max_idx+1, tokens.size());
return std::min((size_t)max_pos, tokens.size());
}
size_t idx = 0;
@ -1669,12 +1669,12 @@ size_t server_tokens::size_up_to_pos(llama_pos max_idx) const {
idx++;
}
if (idx >= max_idx) {
if (pos >= max_pos) {
break;
}
}
return idx+1;
return idx;
}

View File

@ -16,6 +16,28 @@
#include <regex>
#include <exception>
static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1, int32_t offset = 0) {
if (pos_min == -1) {
pos_min = llama_kv_cache_seq_pos_min(ctx, id);
}
if (pos_max == -1) {
pos_max = llama_kv_cache_seq_pos_max(ctx, id);
}
const size_t checkpoint_size = llama_state_seq_get_size(ctx, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
ckpt.pos_min = pos_min;
ckpt.pos_max = pos_max;
ckpt.pos_max_prompt = pos_max + offset;
ckpt.pos_min_prompt = pos_min + offset;
ckpt.n_tokens = n_tokens;
ckpt.data.resize(checkpoint_size);
const size_t n = llama_state_seq_get_data(ctx, ckpt.data.data(), checkpoint_size, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != checkpoint_size) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
}
}
static void log_text(const gpt_params & params_base, const std::string & text) {
if (params_base.minilog) {
LOG_TEE("%s\n", text.c_str());
@ -2916,7 +2938,7 @@ void server_context::discard_n_kv_and_cache_tokens(llama_context* ctx, server_sl
auto kv_past = slot.cache_tokens.pos_next(slot.n_past);
int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
llama_kv_cache_seq_rm(ctx, slot.id, kv_keep, kv_keep + kv_discard);
llama_kv_cache_seq_rm(ctx, slot.id, slot.cache_tokens.pos_next(kv_keep), slot.cache_tokens.pos_next(kv_keep + kv_discard));
llama_kv_cache_seq_add(ctx, slot.id, kv_keep + kv_discard, kv_past, -kv_discard);
if (slot.has_mtp && slot.spec) {
common_speculative_context_shift(slot.spec, slot.id, kv_keep, kv_discard, kv_past);
@ -3227,7 +3249,7 @@ void server_context::apply_checkpoint(server_slot & slot) {
if (slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) {
int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
if (pos_min > pos_min_thold) {
if (pos_min >= pos_min_thold) {
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int)slot.cache_tokens.size(), slot.id, pos_min);
// search for a context checkpoint
@ -3248,15 +3270,17 @@ void server_context::apply_checkpoint(server_slot & slot) {
const size_t n = llama_state_seq_set_data(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != checkpoint_size) {
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024);
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float)checkpoint_size / 1024 / 1024);
do_reset = true;
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
} else {
slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
slot.n_past = slot.cache_tokens.size_up_to_pos(slot.n_past-1);
slot.n_past_prompt = std::min(slot.n_past_prompt, std::max(it->pos_min_prompt + 1, it->pos_max_prompt));
slot.n_past_prompt = slot.prompt_tokens.size_up_to_pos(slot.n_past_prompt-1);
SLT_WRN(slot, "restored context checkpoint took %.2f ms (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", (ggml_time_us() - t_start) / 1000.0, it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024);
pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
slot.n_past = slot.cache_tokens.size_up_to_pos(pos_next);
pos_next = slot.prompt_tokens.pos_next(slot.n_past_prompt);
pos_next = std::min(pos_next, std::max(it->pos_min_prompt + 1, it->pos_max_prompt));
slot.n_past_prompt = slot.prompt_tokens.size_up_to_pos(pos_next);
SLT_WRN(slot, "restored context checkpoint took %.2f ms (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", (ggml_time_us() - t_start) / 1000.0, it->pos_min, it->pos_max, it->n_tokens, slot.n_past, (float)checkpoint_size / 1024 / 1024);
}
}
@ -3267,6 +3291,7 @@ void server_context::apply_checkpoint(server_slot & slot) {
slot.n_past_prompt = 0;
slot.n_past_se = 0;
slot.ga_i = 0;
pos_next = 0;
common_sampler_reset(slot.ctx_sampling);
}
}
@ -3276,7 +3301,7 @@ void server_context::apply_checkpoint(server_slot & slot) {
// erase any checkpoints with pos_min > pos_min_thold
for (auto it = slot.server_cached_prompt.checkpoints.begin(); it != slot.server_cached_prompt.checkpoints.end();) {
const auto & cur = *it;
if (cur.pos_min > pos_min_thold) {
if (cur.pos_max > pos_min_thold) {
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
it = slot.server_cached_prompt.checkpoints.erase(it);
} else {
@ -3292,10 +3317,10 @@ bool server_context::create_checkpoint(server_slot & slot) {
const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
// no need for empty or small checkpoints
do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 16);
do_checkpoint = do_checkpoint && (pos_min >= 0 && slot.cache_tokens.n_tokens() >= 64);
// no need to create checkpoints that are too close together
do_checkpoint = do_checkpoint && (slot.server_cached_prompt.checkpoints.empty() || pos_max > slot.server_cached_prompt.checkpoints.back().pos_max);
do_checkpoint = do_checkpoint && (slot.server_cached_prompt.checkpoints.empty() || slot.cache_tokens.n_tokens() > slot.server_cached_prompt.checkpoints.back().n_tokens);
if (do_checkpoint) {
const int64_t t_start = ggml_time_us();
@ -3303,26 +3328,17 @@ bool server_context::create_checkpoint(server_slot & slot) {
// make room for the new checkpoint, if needed
const auto & cur = slot.server_cached_prompt.checkpoints.front();
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, cur.n_tokens, (float)cur.data.size() / 1024 / 1024);
slot.server_cached_prompt.checkpoints.erase(slot.server_cached_prompt.checkpoints.begin());
}
auto & cur = slot.server_cached_prompt.checkpoints.emplace_back();
server_prompt_checkpoint_update(cur, ctx, slot.id, slot.cache_tokens.n_tokens(), pos_min, pos_max, slot.n_past_offset);
const size_t checkpoint_size = llama_state_seq_get_size(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
auto & cur = slot.server_cached_prompt.checkpoints.emplace_back(server_prompt_checkpoint{
/*.pos_min = */ pos_min,
/*.pos_max = */ pos_max,
/*.pos_min_prompt = */ pos_min + slot.n_past_offset,
/*.pos_max_prompt = */ pos_max + slot.n_past_offset ,
/*.data = */ std::vector<uint8_t>(checkpoint_size),
});
llama_state_seq_get_data(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB, took %.2f ms)\n",
(int)slot.server_cached_prompt.checkpoints.size(), params_base.ctx_checkpoints_n, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024,
SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB, took %.2f ms)\n",
(int)slot.server_cached_prompt.checkpoints.size(), params_base.ctx_checkpoints_n, cur.pos_min, cur.pos_max, cur.n_tokens, (float)cur.data.size() / 1024 / 1024,
(ggml_time_us() - t_start) / 1000.0);
}
return do_checkpoint;
@ -3885,7 +3901,7 @@ void server_context::speculative_decoding_accept() {
llama_set_draft_input_hidden_state(mtp_target, slot.mtp_hidden_state.data());
mtp_accept_tokens(mtp_target, ids, mtp_n_past_base, slot.id);
}
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
llama_kv_cache_seq_rm(ctx, slot.id, slot.cache_tokens.pos_next(slot.n_past), -1);
discard_speculative_checkpoint(slot, ctx);
}
@ -4105,7 +4121,7 @@ inline void rewind_context(server_slot& slot, int32_t ban_pos) {
slot.n_past = slot.cache_tokens.n_tokens();
// Remove from KV cache
llama_kv_cache_seq_rm(slot.ctx, slot.id, slot.n_past, -1);
llama_kv_cache_seq_rm(slot.ctx, slot.id, slot.cache_tokens.pos_next(slot.n_past), -1);
// Truncate buffer
slot.token_buffer.resize(n_keep_buffer);

View File

@ -353,6 +353,8 @@ struct server_prompt_checkpoint {
llama_pos pos_min_prompt;
llama_pos pos_max_prompt;
int64_t n_tokens;
std::vector<uint8_t> data;
size_t size() const {
@ -365,6 +367,7 @@ struct server_prompt_checkpoint {
j["pos_max"] = pos_max;
j["pos_min_prompt"] = pos_min_prompt;
j["pos_max_prompt"] = pos_max_prompt;
j["n_tokens"] = n_tokens;
return j;
}
@ -373,6 +376,7 @@ struct server_prompt_checkpoint {
pos_max = j.value<llama_pos>("pos_max", 0);
pos_min_prompt = j.value<llama_pos>("pos_min_prompt", 0);
pos_max_prompt = j.value<llama_pos>("pos_max_prompt", 0);
n_tokens = j.value<int64_t>("n_tokens", 0);
}
};