From d5037c508a6d60455d1dca95daba1b1a8edb0138 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 20 Jun 2026 16:35:57 +0200 Subject: [PATCH] server: refactor batch construction --- tools/server/server-context.cpp | 705 ++++++++++++++++++-------------- 1 file changed, 387 insertions(+), 318 deletions(-) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 3de1335ec2..a4e72709c5 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -63,6 +63,71 @@ enum slot_state { SLOT_STATE_GENERATING, }; +struct server_slot; // forward declaration + +struct server_batch { + llama_batch batch; + + struct token { + int32_t id_slot; + llama_token token; + llama_pos pos; + bool output; + }; + std::vector tokens; + + // track if given slot can be batched with slots already in the batch + server_slot * slot_batched = nullptr; + + float alora_scale = -1.0f; + size_t alora_disabled_id = 0; + + server_batch() { + batch.token = nullptr; + } + server_batch(int32_t n_tokens_alloc) { + batch = llama_batch_init(n_tokens_alloc, 0, 1); + tokens.reserve(n_tokens_alloc); + } + + ~server_batch() { + llama_batch_free(batch); + } + + bool add(int32_t id_slot, llama_token token, llama_pos pos, bool output) { + GGML_ASSERT(batch.token != nullptr); + if (tokens.size() >= tokens.capacity()) { + return false; + } + tokens.push_back({ id_slot, token, pos, output }); + return true; + } + + void clear() { + tokens.clear(); + slot_batched = nullptr; + alora_scale = -1.0f; + alora_disabled_id = 0; + } + + int32_t size() const { + return (int32_t)tokens.size(); + } + + void set_output(int32_t idx, bool output) { + GGML_ASSERT(idx >= 0 && idx < (int32_t)tokens.size()); + tokens[idx].output = output; + } + + llama_batch render() { + for (int32_t i = 0; i < size(); i++) { + const auto & t = tokens[i]; + common_batch_add(batch, t.token, t.pos, { t.id_slot }, t.output); + } + return batch; + } +}; + struct server_slot { int id; @@ -348,12 +413,13 @@ struct server_slot { return n_draft_max; } - void update_batch(llama_batch & batch) { + // add sampled token of this slot to the batch, optionally add the speculative draft tokens if any + void handle_last_sampled_token(server_batch & batch) { if (spec_draft.empty()) { // no speculative decoding - i_batch = batch.n_tokens; + i_batch = batch.size(); - common_batch_add(batch, sampled, prompt.tokens.pos_next(), { this->id }, true); + batch.add(id, sampled, prompt.tokens.pos_next(), true); SLT_DBG(*this, "slot decode token, id=%d, n_ctx = %d, n_tokens = %d, truncated = %d\n", sampled, n_ctx, prompt.n_tokens(), truncated); @@ -363,16 +429,16 @@ struct server_slot { GGML_ASSERT(spec_i_batch.empty()); - spec_i_batch.push_back(batch.n_tokens); + spec_i_batch.push_back(batch.size()); for (size_t i = 0; i < spec_draft.size(); i++) { - spec_i_batch.push_back(batch.n_tokens + i + 1); + spec_i_batch.push_back(batch.size() + i + 1); } auto pos0 = prompt.tokens.pos_next(); - common_batch_add(batch, sampled, pos0++, { this->id }, true); + batch.add(id, sampled, pos0++, true); for (auto token : spec_draft) { - common_batch_add(batch, token, pos0++, { this->id }, true); + batch.add(this->id, token, pos0++, true); } } @@ -793,7 +859,7 @@ private: llama_context * ctx_tgt = nullptr; - llama_batch batch {}; + server_batch batch; llama_model_ptr model_dft; llama_context_ptr ctx_dft; @@ -845,8 +911,6 @@ private: mtmd_free(mctx); mctx = nullptr; - - llama_batch_free(batch); } void handle_sleeping_state(bool new_state) { @@ -1214,7 +1278,7 @@ private: // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) { const int32_t n_batch = llama_n_batch(ctx_tgt); - batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); + batch = server_batch(std::max(n_batch, params_base.n_parallel)); } if (params_base.cache_ram_mib != 0) { @@ -2511,19 +2575,39 @@ private: if (all_idle) { SRV_INF("%s", "all slots are idle\n"); + return; // skip further processing - return; + } else { + SRV_DBG("%s", "posting NEXT_RESPONSE\n"); + + server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); + task.id = queue_tasks.get_new_id(); + queue_tasks.post(std::move(task)); } } - { - SRV_DBG("%s", "posting NEXT_RESPONSE\n"); - - server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); - task.id = queue_tasks.get_new_id(); - queue_tasks.post(std::move(task)); + llama_batch batch_view; + try { + pre_decode(); + batch_view = batch.render(); + } catch (const std::exception & e) { + SRV_ERR("pre_decode() failed: %s\n", e.what()); } + try { + decode(llama_n_batch(ctx_tgt), batch_view); + } catch (const std::exception & e) { + SRV_ERR("decode() failed: %s\n", e.what()); + } + + try { + post_decode(batch_view); + } catch (const std::exception & e) { + SRV_ERR("post_decode() failed: %s\n", e.what()); + } + } + + void pre_decode() { // apply context-shift if needed // TODO: simplify and improve for (server_slot & slot : slots) { @@ -2594,10 +2678,10 @@ private: } // start populating the batch for this iteration - common_batch_clear(batch); + batch.clear(); // track if given slot can be batched with slots already in the batch - server_slot * slot_batched = nullptr; + auto & slot_batched = batch.slot_batched; std::vector generating; std::vector drafting; @@ -2719,18 +2803,18 @@ private: for (auto * slot_ptr : generating) { auto & slot = *slot_ptr; - slot.update_batch(batch); + slot.handle_last_sampled_token(batch); } // process in chunks of params.n_batch int32_t n_batch = llama_n_batch(ctx_tgt); int32_t n_ubatch = llama_n_ubatch(ctx_tgt); - float alora_scale = -1.0f; - size_t alora_disabled_id = 0; + auto & alora_scale = batch.alora_scale; + auto & alora_disabled_id = batch.alora_disabled_id; // next, batch any pending prompts without exceeding n_batch - if (params_base.cont_batching || batch.n_tokens == 0) { + if (params_base.cont_batching || batch.size() == 0) { for (auto & slot : slots) { if (!slot.is_processing()) { continue; @@ -2752,7 +2836,7 @@ private: const auto & input_tokens = slot.task->tokens; // used to determine the number of tokens added to the batch for the current slot - const auto n_tokens_prev = batch.n_tokens; + const auto n_tokens_prev = batch.size(); // TODO: maybe move branch to outside of this loop in the future if (slot.state == SLOT_STATE_STARTED) { @@ -3048,7 +3132,7 @@ private: if (!slot.can_split()) { // cannot fit the prompt in the current batch - will try next iter - if (batch.n_tokens + slot.task->n_tokens() > n_batch) { + if (batch.size() + slot.task->n_tokens() > n_batch) { continue; } } @@ -3133,7 +3217,7 @@ private: const bool n_before_user_known = n_before_user > 0; // add prompt tokens for processing in the current batch - while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) { + while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.size() < n_batch) { // get next token to process llama_token cur_tok = input_tokens[slot.prompt.n_tokens()]; if (cur_tok == LLAMA_TOKEN_NULL) { @@ -3151,10 +3235,9 @@ private: // embedding requires all tokens in the batch to be output; // MTP also wants logits at every prompt position so the // streaming hook can mirror t_h_nextn into ctx_dft. - common_batch_add(batch, + batch.add(slot.id, cur_tok, slot.prompt.tokens.pos_next(), - { slot.id }, slot.need_embd()); slot.prompt.tokens.push_back(cur_tok); @@ -3190,7 +3273,7 @@ private: } // the number of tokens added to the batch for the current slot - const auto n_tokens_cur = batch.n_tokens - n_tokens_prev; + const auto n_tokens_cur = batch.size() - n_tokens_prev; const bool near_prompt_end = slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch; @@ -3198,13 +3281,13 @@ private: if (slot.prompt.n_tokens() == slot.task->n_tokens()) { slot.state = SLOT_STATE_DONE_PROMPT; - GGML_ASSERT(batch.n_tokens > 0); + GGML_ASSERT(batch.size() > 0); // extract the logits only for the last token - batch.logits[batch.n_tokens - 1] = true; + batch.set_output(batch.size() - 1, true); slot.n_decoded = 0; - slot.i_batch = batch.n_tokens - 1; + slot.i_batch = batch.size() - 1; slot.init_sampler(); } else { @@ -3264,19 +3347,21 @@ private: slot_batched = &slot; } - if (batch.n_tokens >= n_batch) { + if (batch.size() >= n_batch) { break; } } } + } - SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); + bool decode(int32_t n_batch, llama_batch & batch_view) { + SRV_DBG("decoding batch, n_tokens = %d\n", batch.size()); - auto accept_special_token = [&](server_slot & slot, llama_token token) { - return params_base.special || - slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); - }; + auto & slot_batched = batch.slot_batched; + auto & alora_scale = batch.alora_scale; + auto & alora_disabled_id = batch.alora_disabled_id; + // TODO @ngxson : alora handling is too messy, need to refactor it to be more clear and maintainable if (slot_batched) { // apply lora, only need to do it once per batch common_set_adapter_lora(ctx_tgt, slot_batched->lora); @@ -3291,7 +3376,7 @@ private: llama_set_embeddings(ctx_tgt, slot_batched->need_embd()); } - if (batch.n_tokens == 0) { + if (batch.size() == 0) { SRV_WRN("%s", "no tokens to decode\n"); if (++n_empty_consecutive > 3) { @@ -3301,330 +3386,314 @@ private: n_empty_consecutive = 0; } - int32_t i_next = 0; + const int ret = llama_decode(ctx_tgt, batch_view); - // process the created batch of tokens - for (int32_t i = 0; i < batch.n_tokens; i = i_next) { - const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); + metrics.on_decoded(slots); - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - }; + if (ret != 0) { + { + std::string err; - const int ret = llama_decode(ctx_tgt, batch_view); - - metrics.on_decoded(slots); - - if (ret != 0) { - { - std::string err; - - if (n_batch == 1 && ret == 1) { - // TODO: try to terminate only the largest active slot/sequence and continue with the rest - // need to remove the tokens from the current batch too - err = "Context size has been exceeded."; - } - - if (ret == -1) { - err = "Invalid input batch."; - } - - if (ret < -1) { - // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max() - err = "Compute error."; - } - - // TODO: handle ret == 2 (abort) when we start aborting - - if (!err.empty()) { - SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); - - for (auto & slot : slots) { - if (slot.is_processing()) { - send_error(slot, err); - slot.release(); - - // note: it's complicated to keep track of how much of the current batch has been - // processed before the error occurred, so we simply clear the entire context - slot.prompt_clear(false); - } - } - - break; - } + if (n_batch == 1 && ret == 1) { + // TODO: try to terminate only the largest active slot/sequence and continue with the rest + // need to remove the tokens from the current batch too + err = "Context size has been exceeded."; } - // retry with half the batch size to try to find a free slot in the KV cache - if (!try_clear_idle_slots()) { - n_batch /= 2; + if (ret == -1) { + err = "Invalid input batch."; } - SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); + if (ret < -1) { + // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max() + err = "Compute error."; + } - continue; // continue loop of n_batch - } + // TODO: handle ret == 2 (abort) when we start aborting - // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL] - // for now, always re-evaluate for simplicity - // ref: https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4400925384 - if (!common_speculative_process(spec.get(), batch_view)) { - SRV_ERR("%s", "failed to process speculative batch\n"); + if (!err.empty()) { + SRV_ERR("%s n_batch = %d, ret = %d\n", err.c_str(), n_batch, ret); - // TODO: handle error - break; - } + for (auto & slot : slots) { + if (slot.is_processing()) { + send_error(slot, err); + slot.release(); - // move the head of the batch forward with the number of tokens we just processed - i_next = i + n_tokens; - - // on successful decode, restore the original batch size - n_batch = llama_n_batch(ctx_tgt); - - // handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too - for (auto & slot : slots) { - if (slot.state == SLOT_STATE_DONE_PROMPT && slot.task->is_parent()) { - std::vector children; - for (auto & other : slots) { - if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) { - children.push_back(&other); + // note: it's complicated to keep track of how much of the current batch has been + // processed before the error occurred, so we simply clear the entire context + slot.prompt_clear(false); } } - // all children slots should already launched by launch_slots_with_parent_task() - // copy state to the child slots - for (auto & child : children) { - SLT_INF(slot, " - copying state to child %d\n", child->id); - - GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER); - - slot.copy_state_to(*child); - child->state = SLOT_STATE_DONE_PROMPT; - } + // stop, do not retry with smaller batch size + throw std::runtime_error(err); } } - for (auto & slot : slots) { - // optionally send prompt processing progress - if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.task->params.stream && slot.task->params.return_progress) { - send_partial_response(slot, {}, true); + // retry with half the batch size to try to find a free slot in the KV cache + if (!try_clear_idle_slots()) { + n_batch /= 2; + // TODO @ngxson : handle sub-batching + } + + SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, n_batch = %d, ret = %d\n", n_batch, ret); + + return false; // retry with smaller batch size + } + + // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL] + // for now, always re-evaluate for simplicity + // ref: https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4400925384 + if (!common_speculative_process(spec.get(), batch_view)) { + SRV_ERR("%s", "failed to process speculative batch\n"); + + // TODO: handle error + return false; + } + + // handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too + for (auto & slot : slots) { + if (slot.state == SLOT_STATE_DONE_PROMPT && slot.task->is_parent()) { + std::vector children; + for (auto & other : slots) { + if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) { + children.push_back(&other); } } - if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { + // all children slots should already launched by launch_slots_with_parent_task() + // copy state to the child slots + for (auto & child : children) { + SLT_INF(slot, " - copying state to child %d\n", child->id); + + GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER); + + slot.copy_state_to(*child); + child->state = SLOT_STATE_DONE_PROMPT; + } + } + } + + return true; + } + + void post_decode(llama_batch & batch_view) { + auto accept_special_token = [&](server_slot & slot, llama_token token) { + return params_base.special || + slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); + }; + + for (auto & slot : slots) { + // optionally send prompt processing progress + if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) { + if (slot.task->params.stream && slot.task->params.return_progress) { + send_partial_response(slot, {}, true); + } + } + + // TODO @ngxson : bring back slot.i_batch check when sub-batching is implemented + + if (slot.state == SLOT_STATE_DONE_PROMPT) { + if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) { + // prompt evaluated for embedding + send_embedding(slot, batch_view); + slot.release(); + slot.i_batch = -1; continue; // continue loop of slots } - if (slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) { - // prompt evaluated for embedding - send_embedding(slot, batch_view); - slot.release(); - slot.i_batch = -1; - continue; // continue loop of slots - } - - if (slot.task->type == SERVER_TASK_TYPE_RERANK) { - send_rerank(slot, batch_view); - slot.release(); - slot.i_batch = -1; - continue; // continue loop of slots - } - - GGML_ASSERT(slot.task->need_sampling()); - - // prompt evaluated for next-token prediction - slot.state = SLOT_STATE_GENERATING; - - if (slot.can_speculate()) { - common_speculative_begin(spec.get(), slot.id, slot.prompt.tokens.get_text_tokens()); - } - } else if (slot.state != SLOT_STATE_GENERATING) { + if (slot.task->type == SERVER_TASK_TYPE_RERANK) { + send_rerank(slot, batch_view); + slot.release(); + slot.i_batch = -1; continue; // continue loop of slots } - if (slot.can_speculate() && !slot.spec_draft.empty()) { - continue; // sample using speculative decoding + GGML_ASSERT(slot.task->need_sampling()); + + // prompt evaluated for next-token prediction + slot.state = SLOT_STATE_GENERATING; + + if (slot.can_speculate()) { + common_speculative_begin(spec.get(), slot.id, slot.prompt.tokens.get_text_tokens()); + } + } else if (slot.state != SLOT_STATE_GENERATING) { + continue; // continue loop of slots + } + + if (slot.can_speculate() && !slot.spec_draft.empty()) { + continue; // sample using speculative decoding + } + + const int tok_idx = slot.i_batch; + + llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx); + + slot.i_batch = -1; + + common_sampler_accept(slot.smpl.get(), id, true); + + // here we have synchronized the llama_context (due to the sampling above), so we can do time measurement + const int64_t t_now = ggml_time_us(); + + slot.n_decoded += 1; + + if (slot.n_decoded == 1) { + slot.t_start_generation = t_now; + slot.t_print_last = t_now; + slot.n_decoded_last = 0; + slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; + metrics.on_prompt_eval(slot); + } + + slot.t_token_generation = std::max(1, t_now - slot.t_start_generation) / 1e3; + + completion_token_output result; + result.tok = id; + result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs + + if (slot.task->params.sampling.n_probs > 0) { + populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx); + } + + if (!process_token(result, slot)) { + // release slot because of stop condition + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + slot.release(); + + continue; + } + + slot.print_timings_tg(); + } + + // speculative decoding - main model sample and accept + for (auto & slot : slots) { + if (slot.state != SLOT_STATE_GENERATING || !slot.can_speculate() || slot.spec_draft.empty()) { + continue; + } + + // save the original draft size + const size_t n_draft = slot.spec_draft.size(); + + GGML_ASSERT(n_draft > 0); + + // verify and try to accept the draft + { + // save the sampler sampler state in case we need to restore it + common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get())); + + GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1); + auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, slot.spec_draft); + slot.spec_i_batch.clear(); + + GGML_ASSERT(accepted.size() >= 1); + + const uint32_t n_rollback = slot.spec_draft.size() + 1 - accepted.size(); + + const bool use_ckpt_tgt = + ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL || + (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS && n_rollback > llama_n_rs_seq(ctx_tgt)); + + // check for partial draft acceptance + if (n_rollback > 0) { + if (use_ckpt_tgt) { + if (trace > 0) { + SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size()); + } + + // partial acceptance is not supported by the context -> truncate the draft and restore the state + slot.spec_draft = std::move(accepted); + + const auto & ckpt = slot.spec_ckpt; + + SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size()); + + { + ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + common_context_seq_rm(slot.ctx_tgt, slot.id, ckpt.pos_max + 1, -1); + } + + if (slot.ctx_dft) { + ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + common_context_seq_rm(slot.ctx_dft, slot.id, ckpt.pos_max + 1, -1); + } + + slot.prompt.tokens.keep_first(ckpt.n_tokens); + slot.smpl = std::move(smpl_save); + + continue; + } } - const int tok_idx = slot.i_batch - i; + if (trace > 0) { + SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft); + } - llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx); + common_speculative_accept(spec.get(), slot.id, accepted.size() - 1); - slot.i_batch = -1; + slot.spec_draft = std::move(accepted); + } - common_sampler_accept(slot.smpl.get(), id, true); + const int64_t t_now = ggml_time_us(); - // here we have synchronized the llama_context (due to the sampling above), so we can do time measurement - const int64_t t_now = ggml_time_us(); + const auto ids = std::move(slot.spec_draft); + + slot.t_token_generation = std::max(1, t_now - slot.t_start_generation) / 1e3; + + // update how many tokens out of those tested were accepted + slot.n_draft_accepted += ids.size() - 1; + slot.n_draft_verif_steps += 1; + + if (slot.n_accepted_per_pos.empty()) { + slot.n_accepted_per_pos.resize(common_speculative_n_max(¶ms_base.speculative), 0); + } + for (size_t i = 0; i < ids.size() - 1 && i < slot.n_accepted_per_pos.size(); ++i) { + slot.n_accepted_per_pos[i]++; + } + + // add accepted tokens to the prompt + slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft); + slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); + + slot.sampled = ids.back(); // last accepted token + SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft); + + common_context_seq_rm(slot.ctx_tgt, slot.id, slot.prompt.tokens.pos_next(), -1); + if (slot.ctx_dft) { + common_context_seq_rm(slot.ctx_dft, slot.id, slot.prompt.tokens.pos_next(), -1); + } + + for (size_t i = 0; i < ids.size(); ++i) { + completion_token_output result; + + result.tok = ids[i]; + result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // set later + + // TODO: set result.probs slot.n_decoded += 1; - if (slot.n_decoded == 1) { - slot.t_start_generation = t_now; - slot.t_print_last = t_now; - slot.n_decoded_last = 0; - slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; - metrics.on_prompt_eval(slot); - } - - slot.t_token_generation = std::max(1, t_now - slot.t_start_generation) / 1e3; - - completion_token_output result; - result.tok = id; - result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); - result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs - - if (slot.task->params.sampling.n_probs > 0) { - populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx); - } - if (!process_token(result, slot)) { - // release slot because of stop condition slot.print_timings(); send_final_response(slot); metrics.on_prediction(slot); slot.release(); - continue; + break; } - - slot.print_timings_tg(); } - // speculative decoding - main model sample and accept - for (auto & slot : slots) { - if (slot.state != SLOT_STATE_GENERATING || !slot.can_speculate() || slot.spec_draft.empty()) { - continue; - } + slot.print_timings_tg(); - // save the original draft size - const size_t n_draft = slot.spec_draft.size(); - - GGML_ASSERT(n_draft > 0); - - // verify and try to accept the draft - { - // save the sampler sampler state in case we need to restore it - common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get())); - - GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1); - auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, slot.spec_draft); - slot.spec_i_batch.clear(); - - GGML_ASSERT(accepted.size() >= 1); - - const uint32_t n_rollback = slot.spec_draft.size() + 1 - accepted.size(); - - const bool use_ckpt_tgt = - ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL || - (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS && n_rollback > llama_n_rs_seq(ctx_tgt)); - - // check for partial draft acceptance - if (n_rollback > 0) { - if (use_ckpt_tgt) { - if (trace > 0) { - SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size()); - } - - // partial acceptance is not supported by the context -> truncate the draft and restore the state - slot.spec_draft = std::move(accepted); - - const auto & ckpt = slot.spec_ckpt; - - SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size()); - - { - ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - common_context_seq_rm(slot.ctx_tgt, slot.id, ckpt.pos_max + 1, -1); - } - - if (slot.ctx_dft) { - ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - common_context_seq_rm(slot.ctx_dft, slot.id, ckpt.pos_max + 1, -1); - } - - slot.prompt.tokens.keep_first(ckpt.n_tokens); - slot.smpl = std::move(smpl_save); - - continue; - } - } - - if (trace > 0) { - SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft); - } - - common_speculative_accept(spec.get(), slot.id, accepted.size() - 1); - - slot.spec_draft = std::move(accepted); - } - - const int64_t t_now = ggml_time_us(); - - const auto ids = std::move(slot.spec_draft); - - slot.t_token_generation = std::max(1, t_now - slot.t_start_generation) / 1e3; - - // update how many tokens out of those tested were accepted - slot.n_draft_accepted += ids.size() - 1; - slot.n_draft_verif_steps += 1; - - if (slot.n_accepted_per_pos.empty()) { - slot.n_accepted_per_pos.resize(common_speculative_n_max(¶ms_base.speculative), 0); - } - for (size_t i = 0; i < ids.size() - 1 && i < slot.n_accepted_per_pos.size(); ++i) { - slot.n_accepted_per_pos[i]++; - } - - // add accepted tokens to the prompt - slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft); - slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); - - slot.sampled = ids.back(); // last accepted token - SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft); - - common_context_seq_rm(slot.ctx_tgt, slot.id, slot.prompt.tokens.pos_next(), -1); - if (slot.ctx_dft) { - common_context_seq_rm(slot.ctx_dft, slot.id, slot.prompt.tokens.pos_next(), -1); - } - - for (size_t i = 0; i < ids.size(); ++i) { - completion_token_output result; - - result.tok = ids[i]; - result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); - result.prob = 1.0f; // set later - - // TODO: set result.probs - - slot.n_decoded += 1; - - if (!process_token(result, slot)) { - slot.print_timings(); - send_final_response(slot); - metrics.on_prediction(slot); - slot.release(); - - break; - } - } - - slot.print_timings_tg(); - - SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) n_draft, slot.prompt.n_tokens()); - } + SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) n_draft, slot.prompt.n_tokens()); } - - SRV_DBG("%s", "run slots completed\n"); } int get_slot_n_ctx() {