server: create checkpoint on task cancel

This commit is contained in:
Xuan Son Nguyen 2026-05-28 23:19:01 +02:00
parent 751ebd17a5
commit 24c307d261

View File

@ -68,6 +68,7 @@ struct server_slot {
llama_tokens spec_prompt;
std::vector<int32_t> spec_i_batch;
common_prompt_checkpoint spec_ckpt;
int32_t n_ctx_checkpoints = 0;
// TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
// see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
@ -106,6 +107,10 @@ struct server_slot {
// state
slot_state state = SLOT_STATE_IDLE;
// this stores the processed prompt tokens
// during SLOT_STATE_PROCESSING_PROMPT, this is populated progressively as the prompt is processed
// note: inside update_slots(), tokens are appended, but they only become valid in the KV cache after llama_decode()
// outside of update_slots(), prompt tokens are guaranteed to be valid
server_prompt prompt;
void prompt_save(server_prompt_cache & prompt_cache) const {
@ -354,6 +359,33 @@ struct server_slot {
prompt.tokens.insert(spec_draft);
}
// Save the current sequence state of this slot as a new context checkpoint.
//
// n_tokens_cur: the number of tokens added to the batch for the current slot
//               (they are subtracted from prompt.n_tokens() to get the checkpoint's token count)
// pos_min, pos_max: position range recorded in the checkpoint, as supplied by the caller
//
// Once the number of stored checkpoints reaches n_ctx_checkpoints, the oldest ones are erased first.
// Does nothing when n_ctx_checkpoints <= 0 (no room for any checkpoint).
void create_checkpoint(const int64_t n_tokens_cur, llama_pos pos_min, llama_pos pos_max) {
    // TODO @ngxson : avoid create 2 checkpoints for exactly the same prompt

    if (n_ctx_checkpoints <= 0) {
        // with a limit of 0 the eviction loop below can never terminate normally
        // (size() >= 0 is always true) and would access front() of an empty container;
        // a negative limit would wrap around to a huge size_t and disable the cap entirely
        return;
    }

    while (!prompt.checkpoints.empty() && prompt.checkpoints.size() >= (size_t) n_ctx_checkpoints) {
        // make room for the new checkpoint, if needed
        const auto & cur = prompt.checkpoints.front();
        SLT_WRN(*this, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
                cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024);
        prompt.checkpoints.erase(prompt.checkpoints.begin());
    }

    auto & cur = prompt.checkpoints.emplace_back();

    // the tokens of the current batch are excluded from the checkpoint's token count
    cur.update_pos(prompt.n_tokens() - n_tokens_cur, pos_min, pos_max);

    // store the state of this slot's sequence for both the target and the draft context
    cur.update_tgt(ctx_tgt, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
    cur.update_dft(ctx_dft, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    SLT_INF(*this,
            "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
            (int) prompt.checkpoints.size(), n_ctx_checkpoints, cur.pos_min,
            cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024);
}
void release() {
if (is_processing()) {
GGML_ASSERT(task);
@ -376,6 +408,17 @@ struct server_slot {
}
}
void cancel() {
if (is_processing()) {
// create a checkpoint, so that progress is not lost
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), id);
// note: n_tokens_cur = 0 because we are outside of update_slots()
create_checkpoint(0, pos_min, pos_max);
}
release();
}
result_timings get_timings() const {
result_timings timings;
timings.cache_n = n_prompt_tokens_cache;
@ -1059,6 +1102,8 @@ private:
slot.mctx = mctx;
slot.prompt.tokens.has_mtmd = mctx != nullptr;
slot.n_ctx_checkpoints = params_base.n_ctx_checkpoints;
SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
slot.callback_on_release = [this](int id_slot) {
@ -1993,31 +2038,6 @@ private:
return true;
}
// Store a new context checkpoint for the given slot, dropping the oldest ones
// once the configured maximum (params_base.n_ctx_checkpoints) is reached.
// n_tokens_cur: the number of tokens added to the batch for the current slot
void create_checkpoint(server_slot & slot, const int64_t n_tokens_cur, llama_pos pos_min, llama_pos pos_max) {
    auto & ckpts = slot.prompt.checkpoints;

    const size_t n_max = (size_t) params_base.n_ctx_checkpoints;

    // make room for the new checkpoint, if needed
    while (ckpts.size() >= n_max) {
        const auto & oldest = ckpts.front();
        SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
                oldest.pos_min, oldest.pos_max, oldest.n_tokens, (float) oldest.size() / 1024 / 1024);
        ckpts.erase(ckpts.begin());
    }

    auto & fresh = ckpts.emplace_back();

    // tokens that are only queued in the current batch are not counted
    const auto n_tokens_ckpt = slot.prompt.n_tokens() - n_tokens_cur;
    fresh.update_pos(n_tokens_ckpt, pos_min, pos_max);

    // capture the sequence state of the target context, then of the draft context
    fresh.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
    fresh.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    SLT_INF(slot,
            "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
            (int) ckpts.size(), params_base.n_ctx_checkpoints, fresh.pos_min,
            fresh.pos_max, fresh.n_tokens, (float) fresh.size() / 1024 / 1024);
}
void process_single_task(server_task && task) {
switch (task.type) {
case SERVER_TASK_TYPE_COMPLETION:
@ -2084,10 +2104,9 @@ private:
} break;
case SERVER_TASK_TYPE_CANCEL:
{
// release slot linked with the task id
for (auto & slot : slots) {
if (slot.task && slot.task->id == task.id_target) {
slot.release();
slot.cancel();
break;
}
}
@ -3047,7 +3066,7 @@ private:
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
// yet processed and therefore it is not part of the checkpoint.
if (do_checkpoint) {
create_checkpoint(slot, n_tokens_cur, pos_min, pos_max);
slot.create_checkpoint(n_tokens_cur, pos_min, pos_max);
}
}