server: refactor batch construction

This commit is contained in:
Xuan Son Nguyen 2026-06-20 16:35:57 +02:00
parent e27f308597
commit d5037c508a

View File

@ -63,6 +63,71 @@ enum slot_state {
SLOT_STATE_GENERATING,
};
struct server_slot; // forward declaration
struct server_batch {
llama_batch batch;
struct token {
int32_t id_slot;
llama_token token;
llama_pos pos;
bool output;
};
std::vector<token> tokens;
// track if given slot can be batched with slots already in the batch
server_slot * slot_batched = nullptr;
float alora_scale = -1.0f;
size_t alora_disabled_id = 0;
server_batch() {
batch.token = nullptr;
}
server_batch(int32_t n_tokens_alloc) {
batch = llama_batch_init(n_tokens_alloc, 0, 1);
tokens.reserve(n_tokens_alloc);
}
~server_batch() {
llama_batch_free(batch);
}
bool add(int32_t id_slot, llama_token token, llama_pos pos, bool output) {
GGML_ASSERT(batch.token != nullptr);
if (tokens.size() >= tokens.capacity()) {
return false;
}
tokens.push_back({ id_slot, token, pos, output });
return true;
}
void clear() {
tokens.clear();
slot_batched = nullptr;
alora_scale = -1.0f;
alora_disabled_id = 0;
}
int32_t size() const {
return (int32_t)tokens.size();
}
void set_output(int32_t idx, bool output) {
GGML_ASSERT(idx >= 0 && idx < (int32_t)tokens.size());
tokens[idx].output = output;
}
llama_batch render() {
for (int32_t i = 0; i < size(); i++) {
const auto & t = tokens[i];
common_batch_add(batch, t.token, t.pos, { t.id_slot }, t.output);
}
return batch;
}
};
struct server_slot {
int id;
@ -348,12 +413,13 @@ struct server_slot {
return n_draft_max;
}
void update_batch(llama_batch & batch) {
// add sampled token of this slot to the batch, optionally add the speculative draft tokens if any
void handle_last_sampled_token(server_batch & batch) {
if (spec_draft.empty()) {
// no speculative decoding
i_batch = batch.n_tokens;
i_batch = batch.size();
common_batch_add(batch, sampled, prompt.tokens.pos_next(), { this->id }, true);
batch.add(id, sampled, prompt.tokens.pos_next(), true);
SLT_DBG(*this, "slot decode token, id=%d, n_ctx = %d, n_tokens = %d, truncated = %d\n",
sampled, n_ctx, prompt.n_tokens(), truncated);
@ -363,16 +429,16 @@ struct server_slot {
GGML_ASSERT(spec_i_batch.empty());
spec_i_batch.push_back(batch.n_tokens);
spec_i_batch.push_back(batch.size());
for (size_t i = 0; i < spec_draft.size(); i++) {
spec_i_batch.push_back(batch.n_tokens + i + 1);
spec_i_batch.push_back(batch.size() + i + 1);
}
auto pos0 = prompt.tokens.pos_next();
common_batch_add(batch, sampled, pos0++, { this->id }, true);
batch.add(id, sampled, pos0++, true);
for (auto token : spec_draft) {
common_batch_add(batch, token, pos0++, { this->id }, true);
batch.add(this->id, token, pos0++, true);
}
}
@ -793,7 +859,7 @@ private:
llama_context * ctx_tgt = nullptr;
llama_batch batch {};
server_batch batch;
llama_model_ptr model_dft;
llama_context_ptr ctx_dft;
@ -845,8 +911,6 @@ private:
mtmd_free(mctx);
mctx = nullptr;
llama_batch_free(batch);
}
void handle_sleeping_state(bool new_state) {
@ -1214,7 +1278,7 @@ private:
// note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
{
const int32_t n_batch = llama_n_batch(ctx_tgt);
batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
batch = server_batch(std::max(n_batch, params_base.n_parallel));
}
if (params_base.cache_ram_mib != 0) {
@ -2511,19 +2575,39 @@ private:
if (all_idle) {
SRV_INF("%s", "all slots are idle\n");
return; // skip further processing
return;
} else {
SRV_DBG("%s", "posting NEXT_RESPONSE\n");
server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE);
task.id = queue_tasks.get_new_id();
queue_tasks.post(std::move(task));
}
}
{
SRV_DBG("%s", "posting NEXT_RESPONSE\n");
server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE);
task.id = queue_tasks.get_new_id();
queue_tasks.post(std::move(task));
llama_batch batch_view;
try {
pre_decode();
batch_view = batch.render();
} catch (const std::exception & e) {
SRV_ERR("pre_decode() failed: %s\n", e.what());
}
try {
decode(llama_n_batch(ctx_tgt), batch_view);
} catch (const std::exception & e) {
SRV_ERR("decode() failed: %s\n", e.what());
}
try {
post_decode(batch_view);
} catch (const std::exception & e) {
SRV_ERR("post_decode() failed: %s\n", e.what());
}
}
void pre_decode() {
// apply context-shift if needed
// TODO: simplify and improve
for (server_slot & slot : slots) {
@ -2594,10 +2678,10 @@ private:
}
// start populating the batch for this iteration
common_batch_clear(batch);
batch.clear();
// track if given slot can be batched with slots already in the batch
server_slot * slot_batched = nullptr;
auto & slot_batched = batch.slot_batched;
std::vector<server_slot *> generating;
std::vector<server_slot *> drafting;
@ -2719,18 +2803,18 @@ private:
for (auto * slot_ptr : generating) {
auto & slot = *slot_ptr;
slot.update_batch(batch);
slot.handle_last_sampled_token(batch);
}
// process in chunks of params.n_batch
int32_t n_batch = llama_n_batch(ctx_tgt);
int32_t n_ubatch = llama_n_ubatch(ctx_tgt);
float alora_scale = -1.0f;
size_t alora_disabled_id = 0;
auto & alora_scale = batch.alora_scale;
auto & alora_disabled_id = batch.alora_disabled_id;
// next, batch any pending prompts without exceeding n_batch
if (params_base.cont_batching || batch.n_tokens == 0) {
if (params_base.cont_batching || batch.size() == 0) {
for (auto & slot : slots) {
if (!slot.is_processing()) {
continue;
@ -2752,7 +2836,7 @@ private:
const auto & input_tokens = slot.task->tokens;
// used to determine the number of tokens added to the batch for the current slot
const auto n_tokens_prev = batch.n_tokens;
const auto n_tokens_prev = batch.size();
// TODO: maybe move branch to outside of this loop in the future
if (slot.state == SLOT_STATE_STARTED) {
@ -3048,7 +3132,7 @@ private:
if (!slot.can_split()) {
// cannot fit the prompt in the current batch - will try next iter
if (batch.n_tokens + slot.task->n_tokens() > n_batch) {
if (batch.size() + slot.task->n_tokens() > n_batch) {
continue;
}
}
@ -3133,7 +3217,7 @@ private:
const bool n_before_user_known = n_before_user > 0;
// add prompt tokens for processing in the current batch
while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.size() < n_batch) {
// get next token to process
llama_token cur_tok = input_tokens[slot.prompt.n_tokens()];
if (cur_tok == LLAMA_TOKEN_NULL) {
@ -3151,10 +3235,9 @@ private:
// embedding requires all tokens in the batch to be output;
// MTP also wants logits at every prompt position so the
// streaming hook can mirror t_h_nextn into ctx_dft.
common_batch_add(batch,
batch.add(slot.id,
cur_tok,
slot.prompt.tokens.pos_next(),
{ slot.id },
slot.need_embd());
slot.prompt.tokens.push_back(cur_tok);
@ -3190,7 +3273,7 @@ private:
}
// the number of tokens added to the batch for the current slot
const auto n_tokens_cur = batch.n_tokens - n_tokens_prev;
const auto n_tokens_cur = batch.size() - n_tokens_prev;
const bool near_prompt_end = slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch;
@ -3198,13 +3281,13 @@ private:
if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
slot.state = SLOT_STATE_DONE_PROMPT;
GGML_ASSERT(batch.n_tokens > 0);
GGML_ASSERT(batch.size() > 0);
// extract the logits only for the last token
batch.logits[batch.n_tokens - 1] = true;
batch.set_output(batch.size() - 1, true);
slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1;
slot.i_batch = batch.size() - 1;
slot.init_sampler();
} else {
@ -3264,19 +3347,21 @@ private:
slot_batched = &slot;
}
if (batch.n_tokens >= n_batch) {
if (batch.size() >= n_batch) {
break;
}
}
}
}
SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
bool decode(int32_t n_batch, llama_batch & batch_view) {
SRV_DBG("decoding batch, n_tokens = %d\n", batch.size());
auto accept_special_token = [&](server_slot & slot, llama_token token) {
return params_base.special ||
slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end();
};
auto & slot_batched = batch.slot_batched;
auto & alora_scale = batch.alora_scale;
auto & alora_disabled_id = batch.alora_disabled_id;
// TODO @ngxson : alora handling is too messy, need to refactor it to be more clear and maintainable
if (slot_batched) {
// apply lora, only need to do it once per batch
common_set_adapter_lora(ctx_tgt, slot_batched->lora);
@ -3291,7 +3376,7 @@ private:
llama_set_embeddings(ctx_tgt, slot_batched->need_embd());
}
if (batch.n_tokens == 0) {
if (batch.size() == 0) {
SRV_WRN("%s", "no tokens to decode\n");
if (++n_empty_consecutive > 3) {
@ -3301,330 +3386,314 @@ private:
n_empty_consecutive = 0;
}
int32_t i_next = 0;
const int ret = llama_decode(ctx_tgt, batch_view);
// process the created batch of tokens
for (int32_t i = 0; i < batch.n_tokens; i = i_next) {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
metrics.on_decoded(slots);
llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
};
if (ret != 0) {
{
std::string err;
const int ret = llama_decode(ctx_tgt, batch_view);
metrics.on_decoded(slots);
if (ret != 0) {
{
std::string err;
if (n_batch == 1 && ret == 1) {
// TODO: try to terminate only the largest active slot/sequence and continue with the rest
// need to remove the tokens from the current batch too
err = "Context size has been exceeded.";
}
if (ret == -1) {
err = "Invalid input batch.";
}
if (ret < -1) {
// TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
err = "Compute error.";
}
// TODO: handle ret == 2 (abort) when we start aborting
if (!err.empty()) {
SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
for (auto & slot : slots) {
if (slot.is_processing()) {
send_error(slot, err);
slot.release();
// note: it's complicated to keep track of how much of the current batch has been
// processed before the error occurred, so we simply clear the entire context
slot.prompt_clear(false);
}
}
break;
}
if (n_batch == 1 && ret == 1) {
// TODO: try to terminate only the largest active slot/sequence and continue with the rest
// need to remove the tokens from the current batch too
err = "Context size has been exceeded.";
}
// retry with half the batch size to try to find a free slot in the KV cache
if (!try_clear_idle_slots()) {
n_batch /= 2;
if (ret == -1) {
err = "Invalid input batch.";
}
SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
if (ret < -1) {
// TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
err = "Compute error.";
}
continue; // continue loop of n_batch
}
// TODO: handle ret == 2 (abort) when we start aborting
// TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL]
// for now, always re-evaluate for simplicity
// ref: https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4400925384
if (!common_speculative_process(spec.get(), batch_view)) {
SRV_ERR("%s", "failed to process speculative batch\n");
if (!err.empty()) {
SRV_ERR("%s n_batch = %d, ret = %d\n", err.c_str(), n_batch, ret);
// TODO: handle error
break;
}
for (auto & slot : slots) {
if (slot.is_processing()) {
send_error(slot, err);
slot.release();
// move the head of the batch forward with the number of tokens we just processed
i_next = i + n_tokens;
// on successful decode, restore the original batch size
n_batch = llama_n_batch(ctx_tgt);
// handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too
for (auto & slot : slots) {
if (slot.state == SLOT_STATE_DONE_PROMPT && slot.task->is_parent()) {
std::vector<server_slot *> children;
for (auto & other : slots) {
if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
children.push_back(&other);
// note: it's complicated to keep track of how much of the current batch has been
// processed before the error occurred, so we simply clear the entire context
slot.prompt_clear(false);
}
}
// all children slots should already launched by launch_slots_with_parent_task()
// copy state to the child slots
for (auto & child : children) {
SLT_INF(slot, " - copying state to child %d\n", child->id);
GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER);
slot.copy_state_to(*child);
child->state = SLOT_STATE_DONE_PROMPT;
}
// stop, do not retry with smaller batch size
throw std::runtime_error(err);
}
}
for (auto & slot : slots) {
// optionally send prompt processing progress
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
if (slot.task->params.stream && slot.task->params.return_progress) {
send_partial_response(slot, {}, true);
// retry with half the batch size to try to find a free slot in the KV cache
if (!try_clear_idle_slots()) {
n_batch /= 2;
// TODO @ngxson : handle sub-batching
}
SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, n_batch = %d, ret = %d\n", n_batch, ret);
return false; // retry with smaller batch size
}
// TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL]
// for now, always re-evaluate for simplicity
// ref: https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4400925384
if (!common_speculative_process(spec.get(), batch_view)) {
SRV_ERR("%s", "failed to process speculative batch\n");
// TODO: handle error
return false;
}
// handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too
for (auto & slot : slots) {
if (slot.state == SLOT_STATE_DONE_PROMPT && slot.task->is_parent()) {
std::vector<server_slot *> children;
for (auto & other : slots) {
if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
children.push_back(&other);
}
}
if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
// all children slots should already launched by launch_slots_with_parent_task()
// copy state to the child slots
for (auto & child : children) {
SLT_INF(slot, " - copying state to child %d\n", child->id);
GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER);
slot.copy_state_to(*child);
child->state = SLOT_STATE_DONE_PROMPT;
}
}
}
return true;
}
void post_decode(llama_batch & batch_view) {
auto accept_special_token = [&](server_slot & slot, llama_token token) {
return params_base.special ||
slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end();
};
for (auto & slot : slots) {
// optionally send prompt processing progress
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
if (slot.task->params.stream && slot.task->params.return_progress) {
send_partial_response(slot, {}, true);
}
}
// TODO @ngxson : bring back slot.i_batch check when sub-batching is implemented
if (slot.state == SLOT_STATE_DONE_PROMPT) {
if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) {
// prompt evaluated for embedding
send_embedding(slot, batch_view);
slot.release();
slot.i_batch = -1;
continue; // continue loop of slots
}
if (slot.state == SLOT_STATE_DONE_PROMPT) {
if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) {
// prompt evaluated for embedding
send_embedding(slot, batch_view);
slot.release();
slot.i_batch = -1;
continue; // continue loop of slots
}
if (slot.task->type == SERVER_TASK_TYPE_RERANK) {
send_rerank(slot, batch_view);
slot.release();
slot.i_batch = -1;
continue; // continue loop of slots
}
GGML_ASSERT(slot.task->need_sampling());
// prompt evaluated for next-token prediction
slot.state = SLOT_STATE_GENERATING;
if (slot.can_speculate()) {
common_speculative_begin(spec.get(), slot.id, slot.prompt.tokens.get_text_tokens());
}
} else if (slot.state != SLOT_STATE_GENERATING) {
if (slot.task->type == SERVER_TASK_TYPE_RERANK) {
send_rerank(slot, batch_view);
slot.release();
slot.i_batch = -1;
continue; // continue loop of slots
}
if (slot.can_speculate() && !slot.spec_draft.empty()) {
continue; // sample using speculative decoding
GGML_ASSERT(slot.task->need_sampling());
// prompt evaluated for next-token prediction
slot.state = SLOT_STATE_GENERATING;
if (slot.can_speculate()) {
common_speculative_begin(spec.get(), slot.id, slot.prompt.tokens.get_text_tokens());
}
} else if (slot.state != SLOT_STATE_GENERATING) {
continue; // continue loop of slots
}
if (slot.can_speculate() && !slot.spec_draft.empty()) {
continue; // sample using speculative decoding
}
const int tok_idx = slot.i_batch;
llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx);
slot.i_batch = -1;
common_sampler_accept(slot.smpl.get(), id, true);
// here we have synchronized the llama_context (due to the sampling above), so we can do time measurement
const int64_t t_now = ggml_time_us();
slot.n_decoded += 1;
if (slot.n_decoded == 1) {
slot.t_start_generation = t_now;
slot.t_print_last = t_now;
slot.n_decoded_last = 0;
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
metrics.on_prompt_eval(slot);
}
slot.t_token_generation = std::max<int64_t>(1, t_now - slot.t_start_generation) / 1e3;
completion_token_output result;
result.tok = id;
result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok));
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
if (slot.task->params.sampling.n_probs > 0) {
populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx);
}
if (!process_token(result, slot)) {
// release slot because of stop condition
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
slot.release();
continue;
}
slot.print_timings_tg();
}
// speculative decoding - main model sample and accept
for (auto & slot : slots) {
if (slot.state != SLOT_STATE_GENERATING || !slot.can_speculate() || slot.spec_draft.empty()) {
continue;
}
// save the original draft size
const size_t n_draft = slot.spec_draft.size();
GGML_ASSERT(n_draft > 0);
// verify and try to accept the draft
{
// save the sampler sampler state in case we need to restore it
common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get()));
GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1);
auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, slot.spec_draft);
slot.spec_i_batch.clear();
GGML_ASSERT(accepted.size() >= 1);
const uint32_t n_rollback = slot.spec_draft.size() + 1 - accepted.size();
const bool use_ckpt_tgt =
ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL ||
(ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS && n_rollback > llama_n_rs_seq(ctx_tgt));
// check for partial draft acceptance
if (n_rollback > 0) {
if (use_ckpt_tgt) {
if (trace > 0) {
SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size());
}
// partial acceptance is not supported by the context -> truncate the draft and restore the state
slot.spec_draft = std::move(accepted);
const auto & ckpt = slot.spec_ckpt;
SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size());
{
ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
common_context_seq_rm(slot.ctx_tgt, slot.id, ckpt.pos_max + 1, -1);
}
if (slot.ctx_dft) {
ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
common_context_seq_rm(slot.ctx_dft, slot.id, ckpt.pos_max + 1, -1);
}
slot.prompt.tokens.keep_first(ckpt.n_tokens);
slot.smpl = std::move(smpl_save);
continue;
}
}
const int tok_idx = slot.i_batch - i;
if (trace > 0) {
SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft);
}
llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx);
common_speculative_accept(spec.get(), slot.id, accepted.size() - 1);
slot.i_batch = -1;
slot.spec_draft = std::move(accepted);
}
common_sampler_accept(slot.smpl.get(), id, true);
const int64_t t_now = ggml_time_us();
// here we have synchronized the llama_context (due to the sampling above), so we can do time measurement
const int64_t t_now = ggml_time_us();
const auto ids = std::move(slot.spec_draft);
slot.t_token_generation = std::max<int64_t>(1, t_now - slot.t_start_generation) / 1e3;
// update how many tokens out of those tested were accepted
slot.n_draft_accepted += ids.size() - 1;
slot.n_draft_verif_steps += 1;
if (slot.n_accepted_per_pos.empty()) {
slot.n_accepted_per_pos.resize(common_speculative_n_max(&params_base.speculative), 0);
}
for (size_t i = 0; i < ids.size() - 1 && i < slot.n_accepted_per_pos.size(); ++i) {
slot.n_accepted_per_pos[i]++;
}
// add accepted tokens to the prompt
slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
slot.sampled = ids.back(); // last accepted token
SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft);
common_context_seq_rm(slot.ctx_tgt, slot.id, slot.prompt.tokens.pos_next(), -1);
if (slot.ctx_dft) {
common_context_seq_rm(slot.ctx_dft, slot.id, slot.prompt.tokens.pos_next(), -1);
}
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;
result.tok = ids[i];
result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok));
result.prob = 1.0f; // set later
// TODO: set result.probs
slot.n_decoded += 1;
if (slot.n_decoded == 1) {
slot.t_start_generation = t_now;
slot.t_print_last = t_now;
slot.n_decoded_last = 0;
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
metrics.on_prompt_eval(slot);
}
slot.t_token_generation = std::max<int64_t>(1, t_now - slot.t_start_generation) / 1e3;
completion_token_output result;
result.tok = id;
result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok));
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
if (slot.task->params.sampling.n_probs > 0) {
populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx);
}
if (!process_token(result, slot)) {
// release slot because of stop condition
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
slot.release();
continue;
break;
}
slot.print_timings_tg();
}
// speculative decoding - main model sample and accept
for (auto & slot : slots) {
if (slot.state != SLOT_STATE_GENERATING || !slot.can_speculate() || slot.spec_draft.empty()) {
continue;
}
slot.print_timings_tg();
// save the original draft size
const size_t n_draft = slot.spec_draft.size();
GGML_ASSERT(n_draft > 0);
// verify and try to accept the draft
{
// save the sampler sampler state in case we need to restore it
common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get()));
GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1);
auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, slot.spec_draft);
slot.spec_i_batch.clear();
GGML_ASSERT(accepted.size() >= 1);
const uint32_t n_rollback = slot.spec_draft.size() + 1 - accepted.size();
const bool use_ckpt_tgt =
ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL ||
(ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS && n_rollback > llama_n_rs_seq(ctx_tgt));
// check for partial draft acceptance
if (n_rollback > 0) {
if (use_ckpt_tgt) {
if (trace > 0) {
SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size());
}
// partial acceptance is not supported by the context -> truncate the draft and restore the state
slot.spec_draft = std::move(accepted);
const auto & ckpt = slot.spec_ckpt;
SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size());
{
ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
common_context_seq_rm(slot.ctx_tgt, slot.id, ckpt.pos_max + 1, -1);
}
if (slot.ctx_dft) {
ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
common_context_seq_rm(slot.ctx_dft, slot.id, ckpt.pos_max + 1, -1);
}
slot.prompt.tokens.keep_first(ckpt.n_tokens);
slot.smpl = std::move(smpl_save);
continue;
}
}
if (trace > 0) {
SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft);
}
common_speculative_accept(spec.get(), slot.id, accepted.size() - 1);
slot.spec_draft = std::move(accepted);
}
const int64_t t_now = ggml_time_us();
const auto ids = std::move(slot.spec_draft);
slot.t_token_generation = std::max<int64_t>(1, t_now - slot.t_start_generation) / 1e3;
// update how many tokens out of those tested were accepted
slot.n_draft_accepted += ids.size() - 1;
slot.n_draft_verif_steps += 1;
if (slot.n_accepted_per_pos.empty()) {
slot.n_accepted_per_pos.resize(common_speculative_n_max(&params_base.speculative), 0);
}
for (size_t i = 0; i < ids.size() - 1 && i < slot.n_accepted_per_pos.size(); ++i) {
slot.n_accepted_per_pos[i]++;
}
// add accepted tokens to the prompt
slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
slot.sampled = ids.back(); // last accepted token
SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft);
common_context_seq_rm(slot.ctx_tgt, slot.id, slot.prompt.tokens.pos_next(), -1);
if (slot.ctx_dft) {
common_context_seq_rm(slot.ctx_dft, slot.id, slot.prompt.tokens.pos_next(), -1);
}
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;
result.tok = ids[i];
result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok));
result.prob = 1.0f; // set later
// TODO: set result.probs
slot.n_decoded += 1;
if (!process_token(result, slot)) {
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
slot.release();
break;
}
}
slot.print_timings_tg();
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) n_draft, slot.prompt.n_tokens());
}
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) n_draft, slot.prompt.n_tokens());
}
SRV_DBG("%s", "run slots completed\n");
}
int get_slot_n_ctx() {