diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index aeb15096c8..91a8eb9452 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -63,6 +63,99 @@ enum slot_state { SLOT_STATE_GENERATING, }; +struct server_slot; // forward declaration + +struct server_batch { + llama_batch batch; + bool batch_rendered = false; + + struct token { + int32_t id_slot; + llama_token token; + llama_pos pos; + bool output; + }; + std::vector tokens; + int32_t n_tokens_alloc = 0; + + // track if given slot can be batched with slots already in the batch + server_slot * slot_batched = nullptr; + + float alora_scale = -1.0f; + size_t alora_disabled_id = 0; + + server_batch() { + batch.token = nullptr; // sentinel: uninitialized batch + } + + ~server_batch() { + llama_batch_free(batch); + } + + void init(int32_t n_tokens_alloc) { + this->n_tokens_alloc = n_tokens_alloc; + batch = llama_batch_init(n_tokens_alloc, 0, 1); + tokens.reserve(n_tokens_alloc); + } + + bool add(int32_t id_slot, llama_token token, llama_pos pos, bool output) { + GGML_ASSERT(batch.token != nullptr); + if ((int32_t)tokens.size() >= n_tokens_alloc) { + return false; + } + // LOG_INF("adding token to batch: slot=%d, token=%d, pos=%d, output=%d\n", id_slot, token, pos, output); + tokens.push_back({ id_slot, token, pos, output }); + return true; + } + + void clear() { + tokens.clear(); + common_batch_clear(batch); + slot_batched = nullptr; + alora_scale = -1.0f; + alora_disabled_id = 0; + batch_rendered = false; + } + + int32_t size() const { + return (int32_t)tokens.size(); + } + + void set_output(int32_t idx, bool output) { + GGML_ASSERT(idx >= 0 && idx < (int32_t)tokens.size()); + tokens[idx].output = output; + } + + void render() { + GGML_ASSERT(batch.token != nullptr); + common_batch_clear(batch); + for (int32_t i = 0; i < size(); i++) { + const auto & t = tokens[i]; + common_batch_add(batch, t.token, t.pos, { t.id_slot }, t.output); + } + batch_rendered = true; + } + + llama_batch get_view(int32_t off, int32_t n_tokens) const { + GGML_ASSERT(batch.token != nullptr); + GGML_ASSERT(batch_rendered); + GGML_ASSERT(off >= 0 && off < size()); + GGML_ASSERT(n_tokens > 0 && off + n_tokens <= size()); + + llama_batch view = { + n_tokens, + batch.token + off, + nullptr, + batch.pos + off, + batch.n_seq_id + off, + batch.seq_id + off, + batch.logits + off, + }; + + return view; + } +}; + struct server_slot { int id; @@ -185,6 +278,7 @@ struct server_slot { // stats size_t n_sent_text = 0; // number of sent text character + // TODO @ngxson : move all metrics to a sub-struct for clarity int64_t t_start_process_prompt; int64_t t_start_generation; int64_t t_print_last = 0; @@ -348,12 +442,14 @@ struct server_slot { return n_draft_max; } - void update_batch(llama_batch & batch) { + // add sampled token of this slot to the batch, optionally add the speculative draft tokens if any + void handle_last_sampled_token(server_batch & batch) { + bool add_ok = true; if (spec_draft.empty()) { // no speculative decoding - i_batch = batch.n_tokens; + i_batch = batch.size(); - common_batch_add(batch, sampled, prompt.tokens.pos_next(), { this->id }, true); + add_ok &= batch.add(id, sampled, prompt.tokens.pos_next(), true); SLT_DBG(*this, "slot decode token, id=%d, n_ctx = %d, n_tokens = %d, truncated = %d\n", sampled, n_ctx, prompt.n_tokens(), truncated); @@ -363,19 +459,21 @@ struct server_slot { GGML_ASSERT(spec_i_batch.empty()); - spec_i_batch.push_back(batch.n_tokens); + spec_i_batch.push_back(batch.size()); for (size_t i = 0; i < spec_draft.size(); i++) { - spec_i_batch.push_back(batch.n_tokens + i + 1); + spec_i_batch.push_back(batch.size() + i + 1); } auto pos0 = prompt.tokens.pos_next(); - common_batch_add(batch, sampled, pos0++, { this->id }, true); + add_ok &= batch.add(id, sampled, pos0++, true); for (auto token : spec_draft) { - common_batch_add(batch, token, pos0++, { this->id }, true); + add_ok &= batch.add(this->id, token, pos0++, true); } } + GGML_ASSERT(add_ok && "batch must be large enough to hold the sampled and draft tokens"); + prompt.tokens.push_back(sampled); prompt.tokens.insert(spec_draft); } @@ -793,7 +891,7 @@ private: llama_context * ctx_tgt = nullptr; - llama_batch batch {}; + server_batch batch; llama_model_ptr model_dft; llama_context_ptr ctx_dft; @@ -845,8 +943,6 @@ private: mtmd_free(mctx); mctx = nullptr; - - llama_batch_free(batch); } void handle_sleeping_state(bool new_state) { @@ -1266,7 +1362,7 @@ private: // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) { const int32_t n_batch = llama_n_batch(ctx_tgt); - batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); + batch.init(std::max(n_batch, params_base.n_parallel)); } if (params_base.cache_ram_mib != 0) { @@ -2556,7 +2652,83 @@ private: } } + void iterate(std::vector & slots, std::function callback) { + for (auto & slot : slots) { + try { + callback(slot); + } catch (const std::exception & e) { + SLT_ERR(slot, "got exception: %s\n", e.what()); + send_error(slot, std::string("got exception: ") + e.what(), ERROR_TYPE_SERVER); + slot.release(); + } + } + } + + void iterate(std::vector & slots, std::function callback) { + for (auto & slot : slots) { + try { + callback(*slot); + } catch (const std::exception & e) { + SLT_ERR(*slot, "got exception: %s\n", e.what()); + send_error(*slot, std::string("got exception: ") + e.what(), ERROR_TYPE_SERVER); + slot->release(); + } + } + } + + void abort_all_slots(const std::string & reason) { + for (auto & slot : slots) { + if (slot.is_processing()) { + send_error(slot, reason, ERROR_TYPE_SERVER); + slot.release(); + } + } + } + + // @ngxson : for debugging only + int64_t t_pre_decode = 0; + int64_t t_decode = 0; + int64_t t_post_decode = 0; + int64_t t_sampl = 0; + int64_t n_pre_decode = 0; + int64_t n_decode = 0; + int64_t n_post_decode = 0; + int64_t n_sampl = 0; +// #define DEBUG_TIMINGS +#ifdef DEBUG_TIMINGS + struct scoped_timer { + int64_t & t; + int64_t & n; + int64_t t_start; + scoped_timer(int64_t & t_, int64_t & n_) : t(t_), n(n_) { + t_start = ggml_time_us(); + } + ~scoped_timer() { + t += ggml_time_us() - t_start; + n++; + } + }; +#else + struct scoped_timer { + scoped_timer(int64_t &, int64_t &) {} + ~scoped_timer() {} + }; +#endif + void update_slots() { +#ifdef DEBUG_TIMINGS + static int64_t t_prev = 0; + int64_t t_start = ggml_time_us(); + if (t_start - t_prev > 5 * 1000 * 1000) { // every 5 seconds + t_prev = t_start; + SRV_INF("n_pre_decode = %" PRId64 "\n", n_pre_decode); + SRV_INF("avg t_pre_decode = %f ms\n", (double) t_pre_decode / n_pre_decode / 1000.0); + SRV_INF("avg t_decode = %f ms\n", (double) t_decode / n_decode / 1000.0); + SRV_INF("avg t_post_decode = %f ms\n", (double) t_post_decode / n_post_decode / 1000.0); + SRV_INF("avg t_sampl = %f ms\n", (double) t_sampl / n_sampl / 1000.0); + } +#endif + // check if all slots are idle { bool all_idle = true; @@ -2570,29 +2742,80 @@ private: if (all_idle) { SRV_INF("%s", "all slots are idle\n"); + return; // skip further processing - return; + } else { + SRV_DBG("%s", "posting NEXT_RESPONSE\n"); + + server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); + task.id = queue_tasks.get_new_id(); + queue_tasks.post(std::move(task)); } } - { - SRV_DBG("%s", "posting NEXT_RESPONSE\n"); - - server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); - task.id = queue_tasks.get_new_id(); - queue_tasks.post(std::move(task)); + try { + scoped_timer t(t_pre_decode, n_pre_decode); + pre_decode(); + batch.render(); + } catch (const std::exception & e) { + SRV_ERR("pre_decode() failed: %s\n", e.what()); + abort_all_slots("pre_decode() failed: " + std::string(e.what())); } + llama_batch batch_view; + int32_t off_next = 0; + int32_t n_batch = llama_n_batch(ctx_tgt); + for (int32_t off = 0; off < batch.size(); off = off_next) { + const int32_t n_tokens = std::min(n_batch, batch.size() - off); + try { + scoped_timer t(t_decode, n_decode); + // TODO @ngxson : maybe handle n_batch == 1 here instead of inside decode() + + batch_view = batch.get_view(off, n_tokens); + bool ok = decode(n_batch, off, batch_view); +#ifdef DEBUG_TIMINGS + llama_synchronize(ctx_tgt); +#endif + + if (ok) { + // move the head of the batch forward with the number of tokens we just processed + off_next = off + n_tokens; + + // on successful decode, restore the original batch size + n_batch = llama_n_batch(ctx_tgt); + } else { + // try again with the updated n_batch + continue; + } + } catch (const std::exception & e) { + SRV_ERR("decode() failed: %s\n", e.what()); + abort_all_slots("decode() failed: " + std::string(e.what())); + break; // stop any further processing + } + + try { + scoped_timer t(t_post_decode, n_post_decode); + post_decode(n_tokens, off, batch_view); + } catch (const std::exception & e) { + SRV_ERR("post_decode() failed: %s\n", e.what()); + abort_all_slots("post_decode() failed: " + std::string(e.what())); + break; // stop any further processing + } + + } + } + + void pre_decode() { // apply context-shift if needed // TODO: simplify and improve - for (server_slot & slot : slots) { + iterate(slots, [&](server_slot & slot) { if (slot.state == SLOT_STATE_GENERATING && slot.prompt.n_tokens() + 1 >= slot.n_ctx) { if (!params_base.ctx_shift) { // this check is redundant (for good) // we should never get here, because generation should already stopped in process_token() send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); slot.release(); - continue; + return; } if (mctx) { @@ -2604,7 +2827,7 @@ private: if (slot.task->is_parent() || slot.task->is_child()) { send_error(slot, "context shift cannot be used for shared prompt", ERROR_TYPE_SERVER); slot.release(); - continue; + return; } // Shift context @@ -2650,28 +2873,28 @@ private: slot.truncated = true; } - } + }); // start populating the batch for this iteration - common_batch_clear(batch); + batch.clear(); // track if given slot can be batched with slots already in the batch - server_slot * slot_batched = nullptr; + auto & slot_batched = batch.slot_batched; std::vector generating; std::vector drafting; // determine which slots are generating and drafting - for (auto & slot : slots) { + iterate(slots, [&](server_slot & slot) { if (slot.state != SLOT_STATE_GENERATING) { - continue; + return; } // check if we can batch this slot with the previous one if (!slot_batched) { slot_batched = &slot; } else if (!slot_batched->can_batch_with(slot)) { - continue; + return; } generating.push_back(&slot); @@ -2719,7 +2942,7 @@ private: } } } - } + }); // generate the actual drafts (if any) { @@ -2727,9 +2950,7 @@ private: } // make checkpoints if needed - for (auto * slot_ptr : drafting) { - auto & slot = *slot_ptr; - + iterate(drafting, [&](server_slot & slot) { auto & draft = slot.spec_draft; auto & ckpt = slot.spec_ckpt; @@ -2772,38 +2993,42 @@ private: ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); } } - } + }); // update the batch with the sampled/drafted tokens - for (auto * slot_ptr : generating) { - auto & slot = *slot_ptr; - - slot.update_batch(batch); - } + iterate(generating, [&](server_slot & slot) { + slot.handle_last_sampled_token(batch); + }); // process in chunks of params.n_batch int32_t n_batch = llama_n_batch(ctx_tgt); int32_t n_ubatch = llama_n_ubatch(ctx_tgt); - float alora_scale = -1.0f; - size_t alora_disabled_id = 0; + auto & alora_scale = batch.alora_scale; + auto & alora_disabled_id = batch.alora_disabled_id; // next, batch any pending prompts without exceeding n_batch - if (params_base.cont_batching || batch.n_tokens == 0) { - for (auto & slot : slots) { + if (params_base.cont_batching || batch.size() == 0) { + bool add_ok = true; // false means the batch is full, skip remaining slots + + iterate(slots, [&](server_slot & slot) { + if (!add_ok || batch.size() >= n_batch) { + return; // batch is full, skip remaining slots + } + if (!slot.is_processing()) { - continue; + return; } // check if we can batch this slot with the previous one if (slot_batched && !slot_batched->can_batch_with(slot)) { - continue; + return; } // check if this is a child slot if (slot.state == SLOT_STATE_WAIT_OTHER) { SLT_DBG(slot, "%s", "waiting for parent slot to complete\n"); - continue; + return; } // this slot still has a prompt to be processed @@ -2811,7 +3036,7 @@ private: const auto & input_tokens = slot.task->tokens; // used to determine the number of tokens added to the batch for the current slot - const auto n_tokens_prev = batch.n_tokens; + const auto n_tokens_prev = batch.size(); // TODO: maybe move branch to outside of this loop in the future if (slot.state == SLOT_STATE_STARTED) { @@ -2847,14 +3072,14 @@ private: send_final_response(slot); slot.release(); - continue; + return; } // TODO: support memory-less logits computation if (slot.task->need_logits() && !llama_get_memory(ctx_tgt)) { send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER); slot.release(); - continue; + return; } if (!slot.can_split()) { @@ -2866,7 +3091,7 @@ private: slot.task->n_tokens(), n_ubatch), ERROR_TYPE_SERVER); slot.release(); - continue; + return; } if (slot.task->n_tokens() > slot.n_ctx) { @@ -2877,7 +3102,7 @@ private: slot.task->n_tokens(), slot.n_ctx), ERROR_TYPE_EXCEED_CONTEXT_SIZE); slot.release(); - continue; + return; } } else { if (slot.task->n_tokens() >= slot.n_ctx) { @@ -2887,7 +3112,7 @@ private: slot.task->n_tokens(), slot.n_ctx), ERROR_TYPE_EXCEED_CONTEXT_SIZE); slot.release(); - continue; + return; } if (slot.task->params.cache_prompt) { @@ -3107,8 +3332,8 @@ private: if (!slot.can_split()) { // cannot fit the prompt in the current batch - will try next iter - if (batch.n_tokens + slot.task->n_tokens() > n_batch) { - continue; + if (batch.size() + slot.task->n_tokens() > n_batch) { + return; } } @@ -3192,7 +3417,7 @@ private: const bool n_before_user_known = n_before_user > 0; // add prompt tokens for processing in the current batch - while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) { + while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.size() < n_batch) { // get next token to process llama_token cur_tok = input_tokens[slot.prompt.n_tokens()]; if (cur_tok == LLAMA_TOKEN_NULL) { @@ -3210,10 +3435,9 @@ private: // embedding requires all tokens in the batch to be output; // MTP also wants logits at every prompt position so the // streaming hook can mirror t_h_nextn into ctx_dft. - common_batch_add(batch, + add_ok &= batch.add(slot.id, cur_tok, slot.prompt.tokens.pos_next(), - { slot.id }, slot.need_embd()); slot.prompt.tokens.push_back(cur_tok); @@ -3249,7 +3473,7 @@ private: } // the number of tokens added to the batch for the current slot - const auto n_tokens_cur = batch.n_tokens - n_tokens_prev; + const auto n_tokens_cur = batch.size() - n_tokens_prev; const bool near_prompt_end = slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch; @@ -3257,13 +3481,13 @@ private: if (slot.prompt.n_tokens() == slot.task->n_tokens()) { slot.state = SLOT_STATE_DONE_PROMPT; - GGML_ASSERT(batch.n_tokens > 0); + GGML_ASSERT(batch.size() > 0); // extract the logits only for the last token - batch.logits[batch.n_tokens - 1] = true; + batch.set_output(batch.size() - 1, true); slot.n_decoded = 0; - slot.i_batch = batch.n_tokens - 1; + slot.i_batch = batch.size() - 1; slot.init_sampler(); } else { @@ -3322,20 +3546,20 @@ private: if (!slot_batched) { slot_batched = &slot; } - - if (batch.n_tokens >= n_batch) { - break; - } - } + }); } + } - SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); + // returns true = success ; false = retry with smaller batch size + // throw std::runtime_error on fatal error + bool decode(int32_t & n_batch, int32_t off, llama_batch & batch_view) { + SRV_DBG("n_batch (effective) = %d, off = %d\n", n_batch, off); - auto accept_special_token = [&](server_slot & slot, llama_token token) { - return params_base.special || - slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); - }; + auto & slot_batched = batch.slot_batched; + auto & alora_scale = batch.alora_scale; + auto & alora_disabled_id = batch.alora_disabled_id; + // TODO @ngxson : alora handling is too messy, need to refactor it to be more clear and maintainable if (slot_batched) { // apply lora, only need to do it once per batch common_set_adapter_lora(ctx_tgt, slot_batched->lora); @@ -3350,340 +3574,348 @@ private: llama_set_embeddings(ctx_tgt, slot_batched->need_embd()); } - if (batch.n_tokens == 0) { + if (batch.size() == 0) { SRV_WRN("%s", "no tokens to decode\n"); if (++n_empty_consecutive > 3) { GGML_ABORT("fatal error - please provide logs and repro in %s\n", "https://github.com/ggml-org/llama.cpp/pull/20277"); } + + return true; // nothing to decode } else { n_empty_consecutive = 0; } - int32_t i_next = 0; + const int ret = llama_decode(ctx_tgt, batch_view); - // process the created batch of tokens - for (int32_t i = 0; i < batch.n_tokens; i = i_next) { - const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); + metrics.on_decoded(slots); - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - }; + if (ret != 0) { + { + std::string err; - const int ret = llama_decode(ctx_tgt, batch_view); - - metrics.on_decoded(slots); - - if (ret != 0) { - { - std::string err; - - if (n_batch == 1 && ret == 1) { - // TODO: try to terminate only the largest active slot/sequence and continue with the rest - // need to remove the tokens from the current batch too - err = "Context size has been exceeded."; - } - - if (ret == -1) { - err = "Invalid input batch."; - } - - if (ret < -1) { - // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max() - err = "Compute error."; - } - - // TODO: handle ret == 2 (abort) when we start aborting - - if (!err.empty()) { - SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); - - for (auto & slot : slots) { - if (slot.is_processing()) { - send_error(slot, err); - slot.release(); - - // note: it's complicated to keep track of how much of the current batch has been - // processed before the error occurred, so we simply clear the entire context - slot.prompt_clear(false); - } - } - - break; - } + if (n_batch == 1 && ret == 1) { + // TODO: try to terminate only the largest active slot/sequence and continue with the rest + // need to remove the tokens from the current batch too + err = "Context size has been exceeded."; } - // retry with half the batch size to try to find a free slot in the KV cache - if (!try_clear_idle_slots()) { - n_batch /= 2; + if (ret == -1) { + err = "Invalid input batch."; } - SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); + if (ret < -1) { + // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max() + err = "Compute error."; + } - continue; // continue loop of n_batch - } + // TODO: handle ret == 2 (abort) when we start aborting - // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL] - // for now, always re-evaluate for simplicity - // ref: https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4400925384 - if (!common_speculative_process(spec.get(), batch_view)) { - SRV_ERR("%s", "failed to process speculative batch\n"); + if (!err.empty()) { + SRV_ERR("%s off = %d, n_batch = %d, ret = %d\n", err.c_str(), off, n_batch, ret); - // TODO: handle error - break; - } + for (auto & slot : slots) { + if (slot.is_processing()) { + send_error(slot, err); + slot.release(); - // move the head of the batch forward with the number of tokens we just processed - i_next = i + n_tokens; - - // on successful decode, restore the original batch size - n_batch = llama_n_batch(ctx_tgt); - - // handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too - for (auto & slot : slots) { - if (slot.state == SLOT_STATE_DONE_PROMPT && slot.task->is_parent()) { - std::vector children; - for (auto & other : slots) { - if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) { - children.push_back(&other); + // note: it's complicated to keep track of how much of the current batch has been + // processed before the error occurred, so we simply clear the entire context + slot.prompt_clear(false); } } - // all children slots should already launched by launch_slots_with_parent_task() - // copy state to the child slots - for (auto & child : children) { - SLT_INF(slot, " - copying state to child %d\n", child->id); - - GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER); - - slot.copy_state_to(*child); - child->state = SLOT_STATE_DONE_PROMPT; - } + // stop, do not retry with smaller batch size + throw std::runtime_error(err); } } - for (auto & slot : slots) { - // optionally send prompt processing progress - if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.task->params.stream && slot.task->params.return_progress) { - send_partial_response(slot, {}, true); + // retry with half the batch size to try to find a free slot in the KV cache + if (!try_clear_idle_slots()) { + n_batch /= 2; + } + + SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, off = %d, n_batch = %d, ret = %d\n", off, n_batch, ret); + + return false; // retry with the updated n_batch + } + + // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL] + // for now, always re-evaluate for simplicity + // ref: https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4400925384 + if (!common_speculative_process(spec.get(), batch_view)) { + SRV_ERR("%s", "failed to process speculative batch\n"); + + // TODO: handle error + throw std::runtime_error("failed to process speculative batch"); + } + + // handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too + for (auto & slot : slots) { + if (slot.state == SLOT_STATE_DONE_PROMPT && slot.task->is_parent()) { + std::vector children; + for (auto & other : slots) { + if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) { + children.push_back(&other); } } - if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { - continue; // continue loop of slots + // all children slots should already launched by launch_slots_with_parent_task() + // copy state to the child slots + for (auto & child : children) { + SLT_INF(slot, " - copying state to child %d\n", child->id); + + GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER); + + slot.copy_state_to(*child); + child->state = SLOT_STATE_DONE_PROMPT; + } + } + } + + return true; + } + + void post_decode(int32_t n_batch_tokens, int32_t off, llama_batch & batch_view) { + // for checking if a given batch index is inside batch_view + auto is_inside_view = [&](int32_t idx) { + return idx >= off && idx < off + n_batch_tokens; + }; + + // TODO @ngxson : it's tricky to make sub-batch compatible with common_sampler_sample_and_accept_n, + // so for now we will throw an error in this case: https://github.com/ggml-org/llama.cpp/issues/24840 + iterate(slots, [&](server_slot & slot) { + for (auto & i : slot.spec_i_batch) { + if (!is_inside_view(i)) { + throw std::runtime_error(string_format("speculative batch index %d is not inside the current sub-batch [%d, %d)", i, off, off + n_batch_tokens)); + } + } + }); + + auto accept_special_token = [&](server_slot & slot, llama_token token) { + return params_base.special || + slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); + }; + + iterate(slots, [&](server_slot & slot) { + // optionally send prompt processing progress + if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) { + if (slot.task->params.stream && slot.task->params.return_progress) { + send_partial_response(slot, {}, true); + } + } + + if (!is_inside_view(slot.i_batch)) { + // the required token not in this sub-batch, skip + return; + } + + if (slot.state == SLOT_STATE_DONE_PROMPT) { + if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) { + // prompt evaluated for embedding + send_embedding(slot, batch_view); + slot.release(); + slot.i_batch = -1; + return; } - if (slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) { - // prompt evaluated for embedding - send_embedding(slot, batch_view); - slot.release(); - slot.i_batch = -1; - continue; // continue loop of slots - } - - if (slot.task->type == SERVER_TASK_TYPE_RERANK) { - send_rerank(slot, batch_view); - slot.release(); - slot.i_batch = -1; - continue; // continue loop of slots - } - - GGML_ASSERT(slot.task->need_sampling()); - - // prompt evaluated for next-token prediction - slot.state = SLOT_STATE_GENERATING; - - if (slot.can_speculate()) { - common_speculative_begin(spec.get(), slot.id, slot.prompt.tokens.get_text_tokens()); - } - } else if (slot.state != SLOT_STATE_GENERATING) { - continue; // continue loop of slots + if (slot.task->type == SERVER_TASK_TYPE_RERANK) { + send_rerank(slot, batch_view); + slot.release(); + slot.i_batch = -1; + return; } - if (slot.can_speculate() && !slot.spec_draft.empty()) { - continue; // sample using speculative decoding + GGML_ASSERT(slot.task->need_sampling()); + + // prompt evaluated for next-token prediction + slot.state = SLOT_STATE_GENERATING; + + if (slot.can_speculate()) { + common_speculative_begin(spec.get(), slot.id, slot.prompt.tokens.get_text_tokens()); + } + } else if (slot.state != SLOT_STATE_GENERATING) { + return; + } + + if (slot.can_speculate() && !slot.spec_draft.empty()) { + return; // sample using speculative decoding + } + + // shifted according to the current sub-batch + const int tok_idx = slot.i_batch - off; + + llama_token id; + { + scoped_timer timer(t_sampl, n_sampl); + id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx); + } + + slot.i_batch = -1; + + common_sampler_accept(slot.smpl.get(), id, true); + + // here we have synchronized the llama_context (due to the sampling above), so we can do time measurement + const int64_t t_now = ggml_time_us(); + + slot.n_decoded += 1; + + if (slot.n_decoded == 1) { + slot.t_start_generation = t_now; + slot.t_print_last = t_now; + slot.n_decoded_last = 0; + slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; + metrics.on_prompt_eval(slot); + } + + slot.t_token_generation = std::max(1, t_now - slot.t_start_generation) / 1e3; + + completion_token_output result; + result.tok = id; + result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs + + if (slot.task->params.sampling.n_probs > 0) { + populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx); + } + + if (!process_token(result, slot)) { + // release slot because of stop condition + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + slot.release(); + + return; + } + + slot.print_timings_tg(); + }); + + // speculative decoding - main model sample and accept + iterate(slots, [&](server_slot & slot) { + if (slot.state != SLOT_STATE_GENERATING || !slot.can_speculate() || slot.spec_draft.empty()) { + return; + } + + // save the original draft size + const size_t n_draft = slot.spec_draft.size(); + + GGML_ASSERT(n_draft > 0); + + // verify and try to accept the draft + { + // save the sampler sampler state in case we need to restore it + common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get())); + + GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1); + auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, slot.spec_draft); + slot.spec_i_batch.clear(); + + GGML_ASSERT(accepted.size() >= 1); + + const uint32_t n_rollback = slot.spec_draft.size() + 1 - accepted.size(); + + const bool use_ckpt_tgt = + ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL || + (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS && n_rollback > llama_n_rs_seq(ctx_tgt)); + + // check for partial draft acceptance + if (n_rollback > 0) { + if (use_ckpt_tgt) { + if (trace > 0) { + SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size()); + } + + // partial acceptance is not supported by the context -> truncate the draft and restore the state + slot.spec_draft = std::move(accepted); + + const auto & ckpt = slot.spec_ckpt; + + SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size()); + + { + ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + common_context_seq_rm(slot.ctx_tgt, slot.id, ckpt.pos_max + 1, -1); + } + + if (slot.ctx_dft) { + ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + common_context_seq_rm(slot.ctx_dft, slot.id, ckpt.pos_max + 1, -1); + } + + slot.prompt.tokens.keep_first(ckpt.n_tokens); + slot.smpl = std::move(smpl_save); + + return; + } } - const int tok_idx = slot.i_batch - i; + if (trace > 0) { + SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft); + } - llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx); + common_speculative_accept(spec.get(), slot.id, accepted.size() - 1); - slot.i_batch = -1; + slot.spec_draft = std::move(accepted); + } - common_sampler_accept(slot.smpl.get(), id, true); + const int64_t t_now = ggml_time_us(); - // here we have synchronized the llama_context (due to the sampling above), so we can do time measurement - const int64_t t_now = ggml_time_us(); + const auto ids = std::move(slot.spec_draft); + + slot.t_token_generation = std::max(1, t_now - slot.t_start_generation) / 1e3; + + // update how many tokens out of those tested were accepted + slot.n_draft_accepted += ids.size() - 1; + slot.n_draft_verif_steps += 1; + + if (slot.n_accepted_per_pos.empty()) { + slot.n_accepted_per_pos.resize(common_speculative_n_max(¶ms_base.speculative), 0); + } + for (size_t i = 0; i < ids.size() - 1 && i < slot.n_accepted_per_pos.size(); ++i) { + slot.n_accepted_per_pos[i]++; + } + + // add accepted tokens to the prompt + slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft); + slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); + + slot.sampled = ids.back(); // last accepted token + SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft); + + common_context_seq_rm(slot.ctx_tgt, slot.id, slot.prompt.tokens.pos_next(), -1); + if (slot.ctx_dft) { + common_context_seq_rm(slot.ctx_dft, slot.id, slot.prompt.tokens.pos_next(), -1); + } + + for (size_t i = 0; i < ids.size(); ++i) { + completion_token_output result; + + result.tok = ids[i]; + result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // set later + + // TODO: set result.probs slot.n_decoded += 1; - if (slot.n_decoded == 1) { - slot.t_start_generation = t_now; - slot.t_print_last = t_now; - slot.n_decoded_last = 0; - slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; - metrics.on_prompt_eval(slot); - } - - slot.t_token_generation = std::max(1, t_now - slot.t_start_generation) / 1e3; - - completion_token_output result; - result.tok = id; - result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); - result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs - - if (slot.task->params.sampling.n_probs > 0) { - populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx); - } - if (!process_token(result, slot)) { - // release slot because of stop condition slot.print_timings(); send_final_response(slot); metrics.on_prediction(slot); slot.release(); - continue; + return; } - - slot.print_timings_tg(); } - // speculative decoding - main model sample and accept - for (auto & slot : slots) { - if (slot.state != SLOT_STATE_GENERATING || !slot.can_speculate() || slot.spec_draft.empty()) { - continue; - } + slot.print_timings_tg(); - // save the original draft size - const size_t n_draft = slot.spec_draft.size(); - - GGML_ASSERT(n_draft > 0); - - // verify and try to accept the draft - { - // save the sampler sampler state in case we need to restore it - common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get())); - - GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1); - auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, slot.spec_draft); - slot.spec_i_batch.clear(); - - GGML_ASSERT(accepted.size() >= 1); - - const uint32_t n_rollback = slot.spec_draft.size() + 1 - accepted.size(); - - const bool use_ckpt_tgt = - ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL || - (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS && n_rollback > llama_n_rs_seq(ctx_tgt)); - - // check for partial draft acceptance - if (n_rollback > 0) { - if (use_ckpt_tgt) { - if (trace > 0) { - SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size()); - } - - // partial acceptance is not supported by the context -> truncate the draft and restore the state - slot.spec_draft = std::move(accepted); - - const auto & ckpt = slot.spec_ckpt; - - SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size()); - - { - ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - common_context_seq_rm(slot.ctx_tgt, slot.id, ckpt.pos_max + 1, -1); - } - - if (slot.ctx_dft) { - ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - common_context_seq_rm(slot.ctx_dft, slot.id, ckpt.pos_max + 1, -1); - } - - slot.prompt.tokens.keep_first(ckpt.n_tokens); - slot.smpl = std::move(smpl_save); - - continue; - } - } - - if (trace > 0) { - SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft); - } - - common_speculative_accept(spec.get(), slot.id, accepted.size() - 1); - - slot.spec_draft = std::move(accepted); - } - - const int64_t t_now = ggml_time_us(); - - const auto ids = std::move(slot.spec_draft); - - slot.t_token_generation = std::max(1, t_now - slot.t_start_generation) / 1e3; - - // update how many tokens out of those tested were accepted - slot.n_draft_accepted += ids.size() - 1; - slot.n_draft_verif_steps += 1; - - if (slot.n_accepted_per_pos.empty()) { - slot.n_accepted_per_pos.resize(common_speculative_n_max(¶ms_base.speculative), 0); - } - for (size_t i = 0; i < ids.size() - 1 && i < slot.n_accepted_per_pos.size(); ++i) { - slot.n_accepted_per_pos[i]++; - } - - // add accepted tokens to the prompt - slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft); - slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); - - slot.sampled = ids.back(); // last accepted token - SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft); - - common_context_seq_rm(slot.ctx_tgt, slot.id, slot.prompt.tokens.pos_next(), -1); - if (slot.ctx_dft) { - common_context_seq_rm(slot.ctx_dft, slot.id, slot.prompt.tokens.pos_next(), -1); - } - - for (size_t i = 0; i < ids.size(); ++i) { - completion_token_output result; - - result.tok = ids[i]; - result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); - result.prob = 1.0f; // set later - - // TODO: set result.probs - - slot.n_decoded += 1; - - if (!process_token(result, slot)) { - slot.print_timings(); - send_final_response(slot); - metrics.on_prediction(slot); - slot.release(); - - break; - } - } - - slot.print_timings_tg(); - - SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) n_draft, slot.prompt.n_tokens()); - } - } - - SRV_DBG("%s", "run slots completed\n"); + SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) n_draft, slot.prompt.n_tokens()); + }); } int get_slot_n_ctx() {