mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
server: fix mtmd checkpoint restore and avoid checkpoint host copies (#1743)
Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
parent
e722f0bb73
commit
39b3a188e8
@ -1647,9 +1647,9 @@ llama_pos server_tokens::pos_next(int64_t n_tokens) const {
|
||||
}
|
||||
|
||||
|
||||
size_t server_tokens::size_up_to_pos(llama_pos max_idx) const {
|
||||
size_t server_tokens::size_up_to_pos(llama_pos max_pos) const {
|
||||
if (!has_mtmd) {
|
||||
return std::min((size_t)max_idx+1, tokens.size());
|
||||
return std::min((size_t)max_pos, tokens.size());
|
||||
}
|
||||
|
||||
size_t idx = 0;
|
||||
@ -1669,12 +1669,12 @@ size_t server_tokens::size_up_to_pos(llama_pos max_idx) const {
|
||||
idx++;
|
||||
}
|
||||
|
||||
if (idx >= max_idx) {
|
||||
if (pos >= max_pos) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return idx+1;
|
||||
return idx;
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -16,6 +16,28 @@
|
||||
#include <regex>
|
||||
#include <exception>
|
||||
|
||||
// Fill `ckpt` with the state of sequence `id` read from `ctx`, using the
// LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY flag for both the size query and the copy.
//
//   ckpt     - checkpoint to overwrite (positions, n_tokens and data buffer)
//   ctx, id  - context and sequence id the state is read from
//   n_tokens - stored as-is in ckpt.n_tokens
//   pos_min  - lower position to record; -1 means "ask llama_kv_cache_seq_pos_min"
//   pos_max  - upper position to record; -1 means "ask llama_kv_cache_seq_pos_max"
//   offset   - added to both positions to produce pos_min_prompt / pos_max_prompt
//
// Calls GGML_ABORT when the number of bytes reported by llama_state_seq_get_data
// differs from the size reported by llama_state_seq_get_size.
static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1, int32_t offset = 0) {
    // -1 is the "not supplied" sentinel: fall back to what the KV cache reports
    const llama_pos lo = (pos_min == -1) ? llama_kv_cache_seq_pos_min(ctx, id) : pos_min;
    const llama_pos hi = (pos_max == -1) ? llama_kv_cache_seq_pos_max(ctx, id) : pos_max;

    const size_t n_bytes = llama_state_seq_get_size(ctx, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    ckpt.n_tokens       = n_tokens;
    ckpt.pos_min        = lo;
    ckpt.pos_min_prompt = lo + offset;
    ckpt.pos_max        = hi;
    ckpt.pos_max_prompt = hi + offset;

    // size the destination buffer before asking for the state bytes
    ckpt.data.resize(n_bytes);

    const size_t n_written = llama_state_seq_get_data(ctx, ckpt.data.data(), n_bytes, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
    if (n_written == n_bytes) {
        return;
    }

    GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", n_bytes, n_written);
}
|
||||
|
||||
static void log_text(const gpt_params & params_base, const std::string & text) {
|
||||
if (params_base.minilog) {
|
||||
LOG_TEE("%s\n", text.c_str());
|
||||
@ -2916,7 +2938,7 @@ void server_context::discard_n_kv_and_cache_tokens(llama_context* ctx, server_sl
|
||||
auto kv_past = slot.cache_tokens.pos_next(slot.n_past);
|
||||
int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
|
||||
const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
||||
llama_kv_cache_seq_rm(ctx, slot.id, kv_keep, kv_keep + kv_discard);
|
||||
llama_kv_cache_seq_rm(ctx, slot.id, slot.cache_tokens.pos_next(kv_keep), slot.cache_tokens.pos_next(kv_keep + kv_discard));
|
||||
llama_kv_cache_seq_add(ctx, slot.id, kv_keep + kv_discard, kv_past, -kv_discard);
|
||||
if (slot.has_mtp && slot.spec) {
|
||||
common_speculative_context_shift(slot.spec, slot.id, kv_keep, kv_discard, kv_past);
|
||||
@ -3227,7 +3249,7 @@ void server_context::apply_checkpoint(server_slot & slot) {
|
||||
if (slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) {
|
||||
int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
|
||||
|
||||
if (pos_min > pos_min_thold) {
|
||||
if (pos_min >= pos_min_thold) {
|
||||
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int)slot.cache_tokens.size(), slot.id, pos_min);
|
||||
|
||||
// search for a context checkpoint
|
||||
@ -3248,15 +3270,17 @@ void server_context::apply_checkpoint(server_slot & slot) {
|
||||
const size_t n = llama_state_seq_set_data(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
if (n != checkpoint_size) {
|
||||
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024);
|
||||
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float)checkpoint_size / 1024 / 1024);
|
||||
do_reset = true;
|
||||
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
|
||||
} else {
|
||||
slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
|
||||
slot.n_past = slot.cache_tokens.size_up_to_pos(slot.n_past-1);
|
||||
slot.n_past_prompt = std::min(slot.n_past_prompt, std::max(it->pos_min_prompt + 1, it->pos_max_prompt));
|
||||
slot.n_past_prompt = slot.prompt_tokens.size_up_to_pos(slot.n_past_prompt-1);
|
||||
SLT_WRN(slot, "restored context checkpoint took %.2f ms (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", (ggml_time_us() - t_start) / 1000.0, it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024);
|
||||
pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
|
||||
slot.n_past = slot.cache_tokens.size_up_to_pos(pos_next);
|
||||
|
||||
pos_next = slot.prompt_tokens.pos_next(slot.n_past_prompt);
|
||||
pos_next = std::min(pos_next, std::max(it->pos_min_prompt + 1, it->pos_max_prompt));
|
||||
slot.n_past_prompt = slot.prompt_tokens.size_up_to_pos(pos_next);
|
||||
SLT_WRN(slot, "restored context checkpoint took %.2f ms (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", (ggml_time_us() - t_start) / 1000.0, it->pos_min, it->pos_max, it->n_tokens, slot.n_past, (float)checkpoint_size / 1024 / 1024);
|
||||
}
|
||||
}
|
||||
|
||||
@ -3267,6 +3291,7 @@ void server_context::apply_checkpoint(server_slot & slot) {
|
||||
slot.n_past_prompt = 0;
|
||||
slot.n_past_se = 0;
|
||||
slot.ga_i = 0;
|
||||
pos_next = 0;
|
||||
common_sampler_reset(slot.ctx_sampling);
|
||||
}
|
||||
}
|
||||
@ -3276,7 +3301,7 @@ void server_context::apply_checkpoint(server_slot & slot) {
|
||||
// erase any checkpoints with pos_min > pos_min_thold
|
||||
for (auto it = slot.server_cached_prompt.checkpoints.begin(); it != slot.server_cached_prompt.checkpoints.end();) {
|
||||
const auto & cur = *it;
|
||||
if (cur.pos_min > pos_min_thold) {
|
||||
if (cur.pos_max > pos_min_thold) {
|
||||
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
|
||||
it = slot.server_cached_prompt.checkpoints.erase(it);
|
||||
} else {
|
||||
@ -3292,10 +3317,10 @@ bool server_context::create_checkpoint(server_slot & slot) {
|
||||
const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
||||
|
||||
// no need for empty or small checkpoints
|
||||
do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 16);
|
||||
do_checkpoint = do_checkpoint && (pos_min >= 0 && slot.cache_tokens.n_tokens() >= 64);
|
||||
|
||||
// no need to create checkpoints that are too close together
|
||||
do_checkpoint = do_checkpoint && (slot.server_cached_prompt.checkpoints.empty() || pos_max > slot.server_cached_prompt.checkpoints.back().pos_max);
|
||||
do_checkpoint = do_checkpoint && (slot.server_cached_prompt.checkpoints.empty() || slot.cache_tokens.n_tokens() > slot.server_cached_prompt.checkpoints.back().n_tokens);
|
||||
|
||||
if (do_checkpoint) {
|
||||
const int64_t t_start = ggml_time_us();
|
||||
@ -3303,26 +3328,17 @@ bool server_context::create_checkpoint(server_slot & slot) {
|
||||
// make room for the new checkpoint, if needed
|
||||
const auto & cur = slot.server_cached_prompt.checkpoints.front();
|
||||
|
||||
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
||||
cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
|
||||
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
|
||||
cur.pos_min, cur.pos_max, cur.n_tokens, (float)cur.data.size() / 1024 / 1024);
|
||||
|
||||
slot.server_cached_prompt.checkpoints.erase(slot.server_cached_prompt.checkpoints.begin());
|
||||
}
|
||||
|
||||
auto & cur = slot.server_cached_prompt.checkpoints.emplace_back();
|
||||
server_prompt_checkpoint_update(cur, ctx, slot.id, slot.cache_tokens.n_tokens(), pos_min, pos_max, slot.n_past_offset);
|
||||
|
||||
const size_t checkpoint_size = llama_state_seq_get_size(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
auto & cur = slot.server_cached_prompt.checkpoints.emplace_back(server_prompt_checkpoint{
|
||||
/*.pos_min = */ pos_min,
|
||||
/*.pos_max = */ pos_max,
|
||||
/*.pos_min_prompt = */ pos_min + slot.n_past_offset,
|
||||
/*.pos_max_prompt = */ pos_max + slot.n_past_offset ,
|
||||
/*.data = */ std::vector<uint8_t>(checkpoint_size),
|
||||
});
|
||||
|
||||
llama_state_seq_get_data(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB, took %.2f ms)\n",
|
||||
(int)slot.server_cached_prompt.checkpoints.size(), params_base.ctx_checkpoints_n, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024,
|
||||
SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB, took %.2f ms)\n",
|
||||
(int)slot.server_cached_prompt.checkpoints.size(), params_base.ctx_checkpoints_n, cur.pos_min, cur.pos_max, cur.n_tokens, (float)cur.data.size() / 1024 / 1024,
|
||||
(ggml_time_us() - t_start) / 1000.0);
|
||||
}
|
||||
return do_checkpoint;
|
||||
@ -3885,7 +3901,7 @@ void server_context::speculative_decoding_accept() {
|
||||
llama_set_draft_input_hidden_state(mtp_target, slot.mtp_hidden_state.data());
|
||||
mtp_accept_tokens(mtp_target, ids, mtp_n_past_base, slot.id);
|
||||
}
|
||||
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
|
||||
llama_kv_cache_seq_rm(ctx, slot.id, slot.cache_tokens.pos_next(slot.n_past), -1);
|
||||
discard_speculative_checkpoint(slot, ctx);
|
||||
}
|
||||
|
||||
@ -4105,7 +4121,7 @@ inline void rewind_context(server_slot& slot, int32_t ban_pos) {
|
||||
slot.n_past = slot.cache_tokens.n_tokens();
|
||||
|
||||
// Remove from KV cache
|
||||
llama_kv_cache_seq_rm(slot.ctx, slot.id, slot.n_past, -1);
|
||||
llama_kv_cache_seq_rm(slot.ctx, slot.id, slot.cache_tokens.pos_next(slot.n_past), -1);
|
||||
|
||||
// Truncate buffer
|
||||
slot.token_buffer.resize(n_keep_buffer);
|
||||
|
||||
@ -353,6 +353,8 @@ struct server_prompt_checkpoint {
|
||||
llama_pos pos_min_prompt;
|
||||
llama_pos pos_max_prompt;
|
||||
|
||||
int64_t n_tokens;
|
||||
|
||||
std::vector<uint8_t> data;
|
||||
|
||||
size_t size() const {
|
||||
@ -365,6 +367,7 @@ struct server_prompt_checkpoint {
|
||||
j["pos_max"] = pos_max;
|
||||
j["pos_min_prompt"] = pos_min_prompt;
|
||||
j["pos_max_prompt"] = pos_max_prompt;
|
||||
j["n_tokens"] = n_tokens;
|
||||
return j;
|
||||
}
|
||||
|
||||
@ -373,6 +376,7 @@ struct server_prompt_checkpoint {
|
||||
pos_max = j.value<llama_pos>("pos_max", 0);
|
||||
pos_min_prompt = j.value<llama_pos>("pos_min_prompt", 0);
|
||||
pos_max_prompt = j.value<llama_pos>("pos_max_prompt", 0);
|
||||
n_tokens = j.value<int64_t>("n_tokens", 0);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user