mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
server: create checkpoint on task cancel
This commit is contained in:
parent
751ebd17a5
commit
24c307d261
@ -68,6 +68,7 @@ struct server_slot {
|
||||
llama_tokens spec_prompt;
|
||||
std::vector<int32_t> spec_i_batch;
|
||||
common_prompt_checkpoint spec_ckpt;
|
||||
int32_t n_ctx_checkpoints = 0;
|
||||
|
||||
// TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
|
||||
// see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
|
||||
@ -106,6 +107,10 @@ struct server_slot {
|
||||
// state
|
||||
slot_state state = SLOT_STATE_IDLE;
|
||||
|
||||
// this stores the processed prompt tokens
|
||||
// during SLOT_STATE_PROCESSING_PROMPT, this is populated progressively as the prompt is processed
|
||||
// note: inside update_slots(), tokens is appended, but will only become validate in the KV cache after llama_decode()
|
||||
// outside of update_slots(), prompt tokens are guaranteed to be valid
|
||||
server_prompt prompt;
|
||||
|
||||
void prompt_save(server_prompt_cache & prompt_cache) const {
|
||||
@ -354,6 +359,33 @@ struct server_slot {
|
||||
prompt.tokens.insert(spec_draft);
|
||||
}
|
||||
|
||||
// n_tokens_cur: the number of tokens added to the batch for the current slot
|
||||
void create_checkpoint(const int64_t n_tokens_cur, llama_pos pos_min, llama_pos pos_max) {
|
||||
// TODO @ngxson : avoid create 2 checkpoints for exactly the same prompt
|
||||
|
||||
while (prompt.checkpoints.size() >= (size_t) n_ctx_checkpoints) {
|
||||
// make room for the new checkpoint, if needed
|
||||
const auto & cur = prompt.checkpoints.front();
|
||||
|
||||
SLT_WRN(*this, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
|
||||
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024);
|
||||
|
||||
prompt.checkpoints.erase(prompt.checkpoints.begin());
|
||||
}
|
||||
|
||||
auto & cur = prompt.checkpoints.emplace_back();
|
||||
|
||||
cur.update_pos(prompt.n_tokens() - n_tokens_cur, pos_min, pos_max);
|
||||
|
||||
cur.update_tgt(ctx_tgt, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
cur.update_dft(ctx_dft, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
SLT_INF(*this,
|
||||
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
|
||||
(int) prompt.checkpoints.size(), n_ctx_checkpoints, cur.pos_min,
|
||||
cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024);
|
||||
}
|
||||
|
||||
void release() {
|
||||
if (is_processing()) {
|
||||
GGML_ASSERT(task);
|
||||
@ -376,6 +408,17 @@ struct server_slot {
|
||||
}
|
||||
}
|
||||
|
||||
void cancel() {
|
||||
if (is_processing()) {
|
||||
// create a checkpoint, so that progress is not lost
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), id);
|
||||
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), id);
|
||||
// note: n_tokens_cur = 0 because we are outside of update_slots()
|
||||
create_checkpoint(0, pos_min, pos_max);
|
||||
}
|
||||
release();
|
||||
}
|
||||
|
||||
result_timings get_timings() const {
|
||||
result_timings timings;
|
||||
timings.cache_n = n_prompt_tokens_cache;
|
||||
@ -1059,6 +1102,8 @@ private:
|
||||
slot.mctx = mctx;
|
||||
slot.prompt.tokens.has_mtmd = mctx != nullptr;
|
||||
|
||||
slot.n_ctx_checkpoints = params_base.n_ctx_checkpoints;
|
||||
|
||||
SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
|
||||
|
||||
slot.callback_on_release = [this](int id_slot) {
|
||||
@ -1993,31 +2038,6 @@ private:
|
||||
return true;
|
||||
}
|
||||
|
||||
// n_tokens_cur: the number of tokens added to the batch for the current slot
|
||||
void create_checkpoint(server_slot & slot, const int64_t n_tokens_cur, llama_pos pos_min, llama_pos pos_max) {
|
||||
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
|
||||
// make room for the new checkpoint, if needed
|
||||
const auto & cur = slot.prompt.checkpoints.front();
|
||||
|
||||
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
|
||||
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024);
|
||||
|
||||
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
|
||||
}
|
||||
|
||||
auto & cur = slot.prompt.checkpoints.emplace_back();
|
||||
|
||||
cur.update_pos(slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max);
|
||||
|
||||
cur.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
cur.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
SLT_INF(slot,
|
||||
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
|
||||
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min,
|
||||
cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024);
|
||||
}
|
||||
|
||||
void process_single_task(server_task && task) {
|
||||
switch (task.type) {
|
||||
case SERVER_TASK_TYPE_COMPLETION:
|
||||
@ -2084,10 +2104,9 @@ private:
|
||||
} break;
|
||||
case SERVER_TASK_TYPE_CANCEL:
|
||||
{
|
||||
// release slot linked with the task id
|
||||
for (auto & slot : slots) {
|
||||
if (slot.task && slot.task->id == task.id_target) {
|
||||
slot.release();
|
||||
slot.cancel();
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -3047,7 +3066,7 @@ private:
|
||||
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
|
||||
// yet processed and therefore it is not part of the checkpoint.
|
||||
if (do_checkpoint) {
|
||||
create_checkpoint(slot, n_tokens_cur, pos_min, pos_max);
|
||||
slot.create_checkpoint(n_tokens_cur, pos_min, pos_max);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user