mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
server: refactor batch construction
This commit is contained in:
parent
e27f308597
commit
d5037c508a
@ -63,6 +63,71 @@ enum slot_state {
|
||||
SLOT_STATE_GENERATING,
|
||||
};
|
||||
|
||||
struct server_slot; // forward declaration
|
||||
|
||||
struct server_batch {
|
||||
llama_batch batch;
|
||||
|
||||
struct token {
|
||||
int32_t id_slot;
|
||||
llama_token token;
|
||||
llama_pos pos;
|
||||
bool output;
|
||||
};
|
||||
std::vector<token> tokens;
|
||||
|
||||
// track if given slot can be batched with slots already in the batch
|
||||
server_slot * slot_batched = nullptr;
|
||||
|
||||
float alora_scale = -1.0f;
|
||||
size_t alora_disabled_id = 0;
|
||||
|
||||
server_batch() {
|
||||
batch.token = nullptr;
|
||||
}
|
||||
server_batch(int32_t n_tokens_alloc) {
|
||||
batch = llama_batch_init(n_tokens_alloc, 0, 1);
|
||||
tokens.reserve(n_tokens_alloc);
|
||||
}
|
||||
|
||||
~server_batch() {
|
||||
llama_batch_free(batch);
|
||||
}
|
||||
|
||||
bool add(int32_t id_slot, llama_token token, llama_pos pos, bool output) {
|
||||
GGML_ASSERT(batch.token != nullptr);
|
||||
if (tokens.size() >= tokens.capacity()) {
|
||||
return false;
|
||||
}
|
||||
tokens.push_back({ id_slot, token, pos, output });
|
||||
return true;
|
||||
}
|
||||
|
||||
void clear() {
|
||||
tokens.clear();
|
||||
slot_batched = nullptr;
|
||||
alora_scale = -1.0f;
|
||||
alora_disabled_id = 0;
|
||||
}
|
||||
|
||||
int32_t size() const {
|
||||
return (int32_t)tokens.size();
|
||||
}
|
||||
|
||||
void set_output(int32_t idx, bool output) {
|
||||
GGML_ASSERT(idx >= 0 && idx < (int32_t)tokens.size());
|
||||
tokens[idx].output = output;
|
||||
}
|
||||
|
||||
llama_batch render() {
|
||||
for (int32_t i = 0; i < size(); i++) {
|
||||
const auto & t = tokens[i];
|
||||
common_batch_add(batch, t.token, t.pos, { t.id_slot }, t.output);
|
||||
}
|
||||
return batch;
|
||||
}
|
||||
};
|
||||
|
||||
struct server_slot {
|
||||
int id;
|
||||
|
||||
@ -348,12 +413,13 @@ struct server_slot {
|
||||
return n_draft_max;
|
||||
}
|
||||
|
||||
void update_batch(llama_batch & batch) {
|
||||
// add sampled token of this slot to the batch, optionally add the speculative draft tokens if any
|
||||
void handle_last_sampled_token(server_batch & batch) {
|
||||
if (spec_draft.empty()) {
|
||||
// no speculative decoding
|
||||
i_batch = batch.n_tokens;
|
||||
i_batch = batch.size();
|
||||
|
||||
common_batch_add(batch, sampled, prompt.tokens.pos_next(), { this->id }, true);
|
||||
batch.add(id, sampled, prompt.tokens.pos_next(), true);
|
||||
|
||||
SLT_DBG(*this, "slot decode token, id=%d, n_ctx = %d, n_tokens = %d, truncated = %d\n",
|
||||
sampled, n_ctx, prompt.n_tokens(), truncated);
|
||||
@ -363,16 +429,16 @@ struct server_slot {
|
||||
|
||||
GGML_ASSERT(spec_i_batch.empty());
|
||||
|
||||
spec_i_batch.push_back(batch.n_tokens);
|
||||
spec_i_batch.push_back(batch.size());
|
||||
for (size_t i = 0; i < spec_draft.size(); i++) {
|
||||
spec_i_batch.push_back(batch.n_tokens + i + 1);
|
||||
spec_i_batch.push_back(batch.size() + i + 1);
|
||||
}
|
||||
|
||||
auto pos0 = prompt.tokens.pos_next();
|
||||
|
||||
common_batch_add(batch, sampled, pos0++, { this->id }, true);
|
||||
batch.add(id, sampled, pos0++, true);
|
||||
for (auto token : spec_draft) {
|
||||
common_batch_add(batch, token, pos0++, { this->id }, true);
|
||||
batch.add(this->id, token, pos0++, true);
|
||||
}
|
||||
}
|
||||
|
||||
@ -793,7 +859,7 @@ private:
|
||||
|
||||
llama_context * ctx_tgt = nullptr;
|
||||
|
||||
llama_batch batch {};
|
||||
server_batch batch;
|
||||
|
||||
llama_model_ptr model_dft;
|
||||
llama_context_ptr ctx_dft;
|
||||
@ -845,8 +911,6 @@ private:
|
||||
|
||||
mtmd_free(mctx);
|
||||
mctx = nullptr;
|
||||
|
||||
llama_batch_free(batch);
|
||||
}
|
||||
|
||||
void handle_sleeping_state(bool new_state) {
|
||||
@ -1214,7 +1278,7 @@ private:
|
||||
// note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
|
||||
{
|
||||
const int32_t n_batch = llama_n_batch(ctx_tgt);
|
||||
batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
|
||||
batch = server_batch(std::max(n_batch, params_base.n_parallel));
|
||||
}
|
||||
|
||||
if (params_base.cache_ram_mib != 0) {
|
||||
@ -2511,19 +2575,39 @@ private:
|
||||
|
||||
if (all_idle) {
|
||||
SRV_INF("%s", "all slots are idle\n");
|
||||
return; // skip further processing
|
||||
|
||||
return;
|
||||
} else {
|
||||
SRV_DBG("%s", "posting NEXT_RESPONSE\n");
|
||||
|
||||
server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE);
|
||||
task.id = queue_tasks.get_new_id();
|
||||
queue_tasks.post(std::move(task));
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
SRV_DBG("%s", "posting NEXT_RESPONSE\n");
|
||||
|
||||
server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE);
|
||||
task.id = queue_tasks.get_new_id();
|
||||
queue_tasks.post(std::move(task));
|
||||
llama_batch batch_view;
|
||||
try {
|
||||
pre_decode();
|
||||
batch_view = batch.render();
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("pre_decode() failed: %s\n", e.what());
|
||||
}
|
||||
|
||||
try {
|
||||
decode(llama_n_batch(ctx_tgt), batch_view);
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("decode() failed: %s\n", e.what());
|
||||
}
|
||||
|
||||
try {
|
||||
post_decode(batch_view);
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("post_decode() failed: %s\n", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void pre_decode() {
|
||||
// apply context-shift if needed
|
||||
// TODO: simplify and improve
|
||||
for (server_slot & slot : slots) {
|
||||
@ -2594,10 +2678,10 @@ private:
|
||||
}
|
||||
|
||||
// start populating the batch for this iteration
|
||||
common_batch_clear(batch);
|
||||
batch.clear();
|
||||
|
||||
// track if given slot can be batched with slots already in the batch
|
||||
server_slot * slot_batched = nullptr;
|
||||
auto & slot_batched = batch.slot_batched;
|
||||
|
||||
std::vector<server_slot *> generating;
|
||||
std::vector<server_slot *> drafting;
|
||||
@ -2719,18 +2803,18 @@ private:
|
||||
for (auto * slot_ptr : generating) {
|
||||
auto & slot = *slot_ptr;
|
||||
|
||||
slot.update_batch(batch);
|
||||
slot.handle_last_sampled_token(batch);
|
||||
}
|
||||
|
||||
// process in chunks of params.n_batch
|
||||
int32_t n_batch = llama_n_batch(ctx_tgt);
|
||||
int32_t n_ubatch = llama_n_ubatch(ctx_tgt);
|
||||
|
||||
float alora_scale = -1.0f;
|
||||
size_t alora_disabled_id = 0;
|
||||
auto & alora_scale = batch.alora_scale;
|
||||
auto & alora_disabled_id = batch.alora_disabled_id;
|
||||
|
||||
// next, batch any pending prompts without exceeding n_batch
|
||||
if (params_base.cont_batching || batch.n_tokens == 0) {
|
||||
if (params_base.cont_batching || batch.size() == 0) {
|
||||
for (auto & slot : slots) {
|
||||
if (!slot.is_processing()) {
|
||||
continue;
|
||||
@ -2752,7 +2836,7 @@ private:
|
||||
const auto & input_tokens = slot.task->tokens;
|
||||
|
||||
// used to determine the number of tokens added to the batch for the current slot
|
||||
const auto n_tokens_prev = batch.n_tokens;
|
||||
const auto n_tokens_prev = batch.size();
|
||||
|
||||
// TODO: maybe move branch to outside of this loop in the future
|
||||
if (slot.state == SLOT_STATE_STARTED) {
|
||||
@ -3048,7 +3132,7 @@ private:
|
||||
|
||||
if (!slot.can_split()) {
|
||||
// cannot fit the prompt in the current batch - will try next iter
|
||||
if (batch.n_tokens + slot.task->n_tokens() > n_batch) {
|
||||
if (batch.size() + slot.task->n_tokens() > n_batch) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@ -3133,7 +3217,7 @@ private:
|
||||
const bool n_before_user_known = n_before_user > 0;
|
||||
|
||||
// add prompt tokens for processing in the current batch
|
||||
while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
|
||||
while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.size() < n_batch) {
|
||||
// get next token to process
|
||||
llama_token cur_tok = input_tokens[slot.prompt.n_tokens()];
|
||||
if (cur_tok == LLAMA_TOKEN_NULL) {
|
||||
@ -3151,10 +3235,9 @@ private:
|
||||
// embedding requires all tokens in the batch to be output;
|
||||
// MTP also wants logits at every prompt position so the
|
||||
// streaming hook can mirror t_h_nextn into ctx_dft.
|
||||
common_batch_add(batch,
|
||||
batch.add(slot.id,
|
||||
cur_tok,
|
||||
slot.prompt.tokens.pos_next(),
|
||||
{ slot.id },
|
||||
slot.need_embd());
|
||||
slot.prompt.tokens.push_back(cur_tok);
|
||||
|
||||
@ -3190,7 +3273,7 @@ private:
|
||||
}
|
||||
|
||||
// the number of tokens added to the batch for the current slot
|
||||
const auto n_tokens_cur = batch.n_tokens - n_tokens_prev;
|
||||
const auto n_tokens_cur = batch.size() - n_tokens_prev;
|
||||
|
||||
const bool near_prompt_end = slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch;
|
||||
|
||||
@ -3198,13 +3281,13 @@ private:
|
||||
if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
|
||||
slot.state = SLOT_STATE_DONE_PROMPT;
|
||||
|
||||
GGML_ASSERT(batch.n_tokens > 0);
|
||||
GGML_ASSERT(batch.size() > 0);
|
||||
|
||||
// extract the logits only for the last token
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
batch.set_output(batch.size() - 1, true);
|
||||
|
||||
slot.n_decoded = 0;
|
||||
slot.i_batch = batch.n_tokens - 1;
|
||||
slot.i_batch = batch.size() - 1;
|
||||
|
||||
slot.init_sampler();
|
||||
} else {
|
||||
@ -3264,19 +3347,21 @@ private:
|
||||
slot_batched = &slot;
|
||||
}
|
||||
|
||||
if (batch.n_tokens >= n_batch) {
|
||||
if (batch.size() >= n_batch) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
|
||||
bool decode(int32_t n_batch, llama_batch & batch_view) {
|
||||
SRV_DBG("decoding batch, n_tokens = %d\n", batch.size());
|
||||
|
||||
auto accept_special_token = [&](server_slot & slot, llama_token token) {
|
||||
return params_base.special ||
|
||||
slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end();
|
||||
};
|
||||
auto & slot_batched = batch.slot_batched;
|
||||
auto & alora_scale = batch.alora_scale;
|
||||
auto & alora_disabled_id = batch.alora_disabled_id;
|
||||
|
||||
// TODO @ngxson : alora handling is too messy, need to refactor it to be more clear and maintainable
|
||||
if (slot_batched) {
|
||||
// apply lora, only need to do it once per batch
|
||||
common_set_adapter_lora(ctx_tgt, slot_batched->lora);
|
||||
@ -3291,7 +3376,7 @@ private:
|
||||
llama_set_embeddings(ctx_tgt, slot_batched->need_embd());
|
||||
}
|
||||
|
||||
if (batch.n_tokens == 0) {
|
||||
if (batch.size() == 0) {
|
||||
SRV_WRN("%s", "no tokens to decode\n");
|
||||
|
||||
if (++n_empty_consecutive > 3) {
|
||||
@ -3301,330 +3386,314 @@ private:
|
||||
n_empty_consecutive = 0;
|
||||
}
|
||||
|
||||
int32_t i_next = 0;
|
||||
const int ret = llama_decode(ctx_tgt, batch_view);
|
||||
|
||||
// process the created batch of tokens
|
||||
for (int32_t i = 0; i < batch.n_tokens; i = i_next) {
|
||||
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
|
||||
metrics.on_decoded(slots);
|
||||
|
||||
llama_batch batch_view = {
|
||||
n_tokens,
|
||||
batch.token + i,
|
||||
nullptr,
|
||||
batch.pos + i,
|
||||
batch.n_seq_id + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
};
|
||||
if (ret != 0) {
|
||||
{
|
||||
std::string err;
|
||||
|
||||
const int ret = llama_decode(ctx_tgt, batch_view);
|
||||
|
||||
metrics.on_decoded(slots);
|
||||
|
||||
if (ret != 0) {
|
||||
{
|
||||
std::string err;
|
||||
|
||||
if (n_batch == 1 && ret == 1) {
|
||||
// TODO: try to terminate only the largest active slot/sequence and continue with the rest
|
||||
// need to remove the tokens from the current batch too
|
||||
err = "Context size has been exceeded.";
|
||||
}
|
||||
|
||||
if (ret == -1) {
|
||||
err = "Invalid input batch.";
|
||||
}
|
||||
|
||||
if (ret < -1) {
|
||||
// TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
|
||||
err = "Compute error.";
|
||||
}
|
||||
|
||||
// TODO: handle ret == 2 (abort) when we start aborting
|
||||
|
||||
if (!err.empty()) {
|
||||
SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
|
||||
|
||||
for (auto & slot : slots) {
|
||||
if (slot.is_processing()) {
|
||||
send_error(slot, err);
|
||||
slot.release();
|
||||
|
||||
// note: it's complicated to keep track of how much of the current batch has been
|
||||
// processed before the error occurred, so we simply clear the entire context
|
||||
slot.prompt_clear(false);
|
||||
}
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
if (n_batch == 1 && ret == 1) {
|
||||
// TODO: try to terminate only the largest active slot/sequence and continue with the rest
|
||||
// need to remove the tokens from the current batch too
|
||||
err = "Context size has been exceeded.";
|
||||
}
|
||||
|
||||
// retry with half the batch size to try to find a free slot in the KV cache
|
||||
if (!try_clear_idle_slots()) {
|
||||
n_batch /= 2;
|
||||
if (ret == -1) {
|
||||
err = "Invalid input batch.";
|
||||
}
|
||||
|
||||
SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
|
||||
if (ret < -1) {
|
||||
// TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
|
||||
err = "Compute error.";
|
||||
}
|
||||
|
||||
continue; // continue loop of n_batch
|
||||
}
|
||||
// TODO: handle ret == 2 (abort) when we start aborting
|
||||
|
||||
// TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL]
|
||||
// for now, always re-evaluate for simplicity
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4400925384
|
||||
if (!common_speculative_process(spec.get(), batch_view)) {
|
||||
SRV_ERR("%s", "failed to process speculative batch\n");
|
||||
if (!err.empty()) {
|
||||
SRV_ERR("%s n_batch = %d, ret = %d\n", err.c_str(), n_batch, ret);
|
||||
|
||||
// TODO: handle error
|
||||
break;
|
||||
}
|
||||
for (auto & slot : slots) {
|
||||
if (slot.is_processing()) {
|
||||
send_error(slot, err);
|
||||
slot.release();
|
||||
|
||||
// move the head of the batch forward with the number of tokens we just processed
|
||||
i_next = i + n_tokens;
|
||||
|
||||
// on successful decode, restore the original batch size
|
||||
n_batch = llama_n_batch(ctx_tgt);
|
||||
|
||||
// handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too
|
||||
for (auto & slot : slots) {
|
||||
if (slot.state == SLOT_STATE_DONE_PROMPT && slot.task->is_parent()) {
|
||||
std::vector<server_slot *> children;
|
||||
for (auto & other : slots) {
|
||||
if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
|
||||
children.push_back(&other);
|
||||
// note: it's complicated to keep track of how much of the current batch has been
|
||||
// processed before the error occurred, so we simply clear the entire context
|
||||
slot.prompt_clear(false);
|
||||
}
|
||||
}
|
||||
|
||||
// all children slots should already be launched by launch_slots_with_parent_task()
|
||||
// copy state to the child slots
|
||||
for (auto & child : children) {
|
||||
SLT_INF(slot, " - copying state to child %d\n", child->id);
|
||||
|
||||
GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER);
|
||||
|
||||
slot.copy_state_to(*child);
|
||||
child->state = SLOT_STATE_DONE_PROMPT;
|
||||
}
|
||||
// stop, do not retry with smaller batch size
|
||||
throw std::runtime_error(err);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto & slot : slots) {
|
||||
// optionally send prompt processing progress
|
||||
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
|
||||
if (slot.task->params.stream && slot.task->params.return_progress) {
|
||||
send_partial_response(slot, {}, true);
|
||||
// retry with half the batch size to try to find a free slot in the KV cache
|
||||
if (!try_clear_idle_slots()) {
|
||||
n_batch /= 2;
|
||||
// TODO @ngxson : handle sub-batching
|
||||
}
|
||||
|
||||
SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, n_batch = %d, ret = %d\n", n_batch, ret);
|
||||
|
||||
return false; // retry with smaller batch size
|
||||
}
|
||||
|
||||
// TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL]
|
||||
// for now, always re-evaluate for simplicity
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4400925384
|
||||
if (!common_speculative_process(spec.get(), batch_view)) {
|
||||
SRV_ERR("%s", "failed to process speculative batch\n");
|
||||
|
||||
// TODO: handle error
|
||||
return false;
|
||||
}
|
||||
|
||||
// handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too
|
||||
for (auto & slot : slots) {
|
||||
if (slot.state == SLOT_STATE_DONE_PROMPT && slot.task->is_parent()) {
|
||||
std::vector<server_slot *> children;
|
||||
for (auto & other : slots) {
|
||||
if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
|
||||
children.push_back(&other);
|
||||
}
|
||||
}
|
||||
|
||||
if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
|
||||
// all children slots should already be launched by launch_slots_with_parent_task()
|
||||
// copy state to the child slots
|
||||
for (auto & child : children) {
|
||||
SLT_INF(slot, " - copying state to child %d\n", child->id);
|
||||
|
||||
GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER);
|
||||
|
||||
slot.copy_state_to(*child);
|
||||
child->state = SLOT_STATE_DONE_PROMPT;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void post_decode(llama_batch & batch_view) {
|
||||
auto accept_special_token = [&](server_slot & slot, llama_token token) {
|
||||
return params_base.special ||
|
||||
slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end();
|
||||
};
|
||||
|
||||
for (auto & slot : slots) {
|
||||
// optionally send prompt processing progress
|
||||
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
|
||||
if (slot.task->params.stream && slot.task->params.return_progress) {
|
||||
send_partial_response(slot, {}, true);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO @ngxson : bring back slot.i_batch check when sub-batching is implemented
|
||||
|
||||
if (slot.state == SLOT_STATE_DONE_PROMPT) {
|
||||
if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) {
|
||||
// prompt evaluated for embedding
|
||||
send_embedding(slot, batch_view);
|
||||
slot.release();
|
||||
slot.i_batch = -1;
|
||||
continue; // continue loop of slots
|
||||
}
|
||||
|
||||
if (slot.state == SLOT_STATE_DONE_PROMPT) {
|
||||
if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) {
|
||||
// prompt evaluated for embedding
|
||||
send_embedding(slot, batch_view);
|
||||
slot.release();
|
||||
slot.i_batch = -1;
|
||||
continue; // continue loop of slots
|
||||
}
|
||||
|
||||
if (slot.task->type == SERVER_TASK_TYPE_RERANK) {
|
||||
send_rerank(slot, batch_view);
|
||||
slot.release();
|
||||
slot.i_batch = -1;
|
||||
continue; // continue loop of slots
|
||||
}
|
||||
|
||||
GGML_ASSERT(slot.task->need_sampling());
|
||||
|
||||
// prompt evaluated for next-token prediction
|
||||
slot.state = SLOT_STATE_GENERATING;
|
||||
|
||||
if (slot.can_speculate()) {
|
||||
common_speculative_begin(spec.get(), slot.id, slot.prompt.tokens.get_text_tokens());
|
||||
}
|
||||
} else if (slot.state != SLOT_STATE_GENERATING) {
|
||||
if (slot.task->type == SERVER_TASK_TYPE_RERANK) {
|
||||
send_rerank(slot, batch_view);
|
||||
slot.release();
|
||||
slot.i_batch = -1;
|
||||
continue; // continue loop of slots
|
||||
}
|
||||
|
||||
if (slot.can_speculate() && !slot.spec_draft.empty()) {
|
||||
continue; // sample using speculative decoding
|
||||
GGML_ASSERT(slot.task->need_sampling());
|
||||
|
||||
// prompt evaluated for next-token prediction
|
||||
slot.state = SLOT_STATE_GENERATING;
|
||||
|
||||
if (slot.can_speculate()) {
|
||||
common_speculative_begin(spec.get(), slot.id, slot.prompt.tokens.get_text_tokens());
|
||||
}
|
||||
} else if (slot.state != SLOT_STATE_GENERATING) {
|
||||
continue; // continue loop of slots
|
||||
}
|
||||
|
||||
if (slot.can_speculate() && !slot.spec_draft.empty()) {
|
||||
continue; // sample using speculative decoding
|
||||
}
|
||||
|
||||
const int tok_idx = slot.i_batch;
|
||||
|
||||
llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx);
|
||||
|
||||
slot.i_batch = -1;
|
||||
|
||||
common_sampler_accept(slot.smpl.get(), id, true);
|
||||
|
||||
// here we have synchronized the llama_context (due to the sampling above), so we can do time measurement
|
||||
const int64_t t_now = ggml_time_us();
|
||||
|
||||
slot.n_decoded += 1;
|
||||
|
||||
if (slot.n_decoded == 1) {
|
||||
slot.t_start_generation = t_now;
|
||||
slot.t_print_last = t_now;
|
||||
slot.n_decoded_last = 0;
|
||||
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
|
||||
metrics.on_prompt_eval(slot);
|
||||
}
|
||||
|
||||
slot.t_token_generation = std::max<int64_t>(1, t_now - slot.t_start_generation) / 1e3;
|
||||
|
||||
completion_token_output result;
|
||||
result.tok = id;
|
||||
result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok));
|
||||
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
|
||||
|
||||
if (slot.task->params.sampling.n_probs > 0) {
|
||||
populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx);
|
||||
}
|
||||
|
||||
if (!process_token(result, slot)) {
|
||||
// release slot because of stop condition
|
||||
slot.print_timings();
|
||||
send_final_response(slot);
|
||||
metrics.on_prediction(slot);
|
||||
slot.release();
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
slot.print_timings_tg();
|
||||
}
|
||||
|
||||
// speculative decoding - main model sample and accept
|
||||
for (auto & slot : slots) {
|
||||
if (slot.state != SLOT_STATE_GENERATING || !slot.can_speculate() || slot.spec_draft.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// save the original draft size
|
||||
const size_t n_draft = slot.spec_draft.size();
|
||||
|
||||
GGML_ASSERT(n_draft > 0);
|
||||
|
||||
// verify and try to accept the draft
|
||||
{
|
||||
// save the sampler state in case we need to restore it
|
||||
common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get()));
|
||||
|
||||
GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1);
|
||||
auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, slot.spec_draft);
|
||||
slot.spec_i_batch.clear();
|
||||
|
||||
GGML_ASSERT(accepted.size() >= 1);
|
||||
|
||||
const uint32_t n_rollback = slot.spec_draft.size() + 1 - accepted.size();
|
||||
|
||||
const bool use_ckpt_tgt =
|
||||
ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL ||
|
||||
(ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS && n_rollback > llama_n_rs_seq(ctx_tgt));
|
||||
|
||||
// check for partial draft acceptance
|
||||
if (n_rollback > 0) {
|
||||
if (use_ckpt_tgt) {
|
||||
if (trace > 0) {
|
||||
SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size());
|
||||
}
|
||||
|
||||
// partial acceptance is not supported by the context -> truncate the draft and restore the state
|
||||
slot.spec_draft = std::move(accepted);
|
||||
|
||||
const auto & ckpt = slot.spec_ckpt;
|
||||
|
||||
SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size());
|
||||
|
||||
{
|
||||
ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
common_context_seq_rm(slot.ctx_tgt, slot.id, ckpt.pos_max + 1, -1);
|
||||
}
|
||||
|
||||
if (slot.ctx_dft) {
|
||||
ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
common_context_seq_rm(slot.ctx_dft, slot.id, ckpt.pos_max + 1, -1);
|
||||
}
|
||||
|
||||
slot.prompt.tokens.keep_first(ckpt.n_tokens);
|
||||
slot.smpl = std::move(smpl_save);
|
||||
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const int tok_idx = slot.i_batch - i;
|
||||
if (trace > 0) {
|
||||
SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft);
|
||||
}
|
||||
|
||||
llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx);
|
||||
common_speculative_accept(spec.get(), slot.id, accepted.size() - 1);
|
||||
|
||||
slot.i_batch = -1;
|
||||
slot.spec_draft = std::move(accepted);
|
||||
}
|
||||
|
||||
common_sampler_accept(slot.smpl.get(), id, true);
|
||||
const int64_t t_now = ggml_time_us();
|
||||
|
||||
// here we have synchronized the llama_context (due to the sampling above), so we can do time measurement
|
||||
const int64_t t_now = ggml_time_us();
|
||||
const auto ids = std::move(slot.spec_draft);
|
||||
|
||||
slot.t_token_generation = std::max<int64_t>(1, t_now - slot.t_start_generation) / 1e3;
|
||||
|
||||
// update how many tokens out of those tested were accepted
|
||||
slot.n_draft_accepted += ids.size() - 1;
|
||||
slot.n_draft_verif_steps += 1;
|
||||
|
||||
if (slot.n_accepted_per_pos.empty()) {
|
||||
slot.n_accepted_per_pos.resize(common_speculative_n_max(¶ms_base.speculative), 0);
|
||||
}
|
||||
for (size_t i = 0; i < ids.size() - 1 && i < slot.n_accepted_per_pos.size(); ++i) {
|
||||
slot.n_accepted_per_pos[i]++;
|
||||
}
|
||||
|
||||
// add accepted tokens to the prompt
|
||||
slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
|
||||
slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
|
||||
|
||||
slot.sampled = ids.back(); // last accepted token
|
||||
SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft);
|
||||
|
||||
common_context_seq_rm(slot.ctx_tgt, slot.id, slot.prompt.tokens.pos_next(), -1);
|
||||
if (slot.ctx_dft) {
|
||||
common_context_seq_rm(slot.ctx_dft, slot.id, slot.prompt.tokens.pos_next(), -1);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < ids.size(); ++i) {
|
||||
completion_token_output result;
|
||||
|
||||
result.tok = ids[i];
|
||||
result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok));
|
||||
result.prob = 1.0f; // set later
|
||||
|
||||
// TODO: set result.probs
|
||||
|
||||
slot.n_decoded += 1;
|
||||
|
||||
if (slot.n_decoded == 1) {
|
||||
slot.t_start_generation = t_now;
|
||||
slot.t_print_last = t_now;
|
||||
slot.n_decoded_last = 0;
|
||||
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
|
||||
metrics.on_prompt_eval(slot);
|
||||
}
|
||||
|
||||
slot.t_token_generation = std::max<int64_t>(1, t_now - slot.t_start_generation) / 1e3;
|
||||
|
||||
completion_token_output result;
|
||||
result.tok = id;
|
||||
result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok));
|
||||
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
|
||||
|
||||
if (slot.task->params.sampling.n_probs > 0) {
|
||||
populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx);
|
||||
}
|
||||
|
||||
if (!process_token(result, slot)) {
|
||||
// release slot because of stop condition
|
||||
slot.print_timings();
|
||||
send_final_response(slot);
|
||||
metrics.on_prediction(slot);
|
||||
slot.release();
|
||||
|
||||
continue;
|
||||
break;
|
||||
}
|
||||
|
||||
slot.print_timings_tg();
|
||||
}
|
||||
|
||||
// speculative decoding - main model sample and accept
|
||||
for (auto & slot : slots) {
|
||||
if (slot.state != SLOT_STATE_GENERATING || !slot.can_speculate() || slot.spec_draft.empty()) {
|
||||
continue;
|
||||
}
|
||||
slot.print_timings_tg();
|
||||
|
||||
// save the original draft size
|
||||
const size_t n_draft = slot.spec_draft.size();
|
||||
|
||||
GGML_ASSERT(n_draft > 0);
|
||||
|
||||
// verify and try to accept the draft
|
||||
{
|
||||
// save the sampler state in case we need to restore it
|
||||
common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get()));
|
||||
|
||||
GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1);
|
||||
auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, slot.spec_draft);
|
||||
slot.spec_i_batch.clear();
|
||||
|
||||
GGML_ASSERT(accepted.size() >= 1);
|
||||
|
||||
const uint32_t n_rollback = slot.spec_draft.size() + 1 - accepted.size();
|
||||
|
||||
const bool use_ckpt_tgt =
|
||||
ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL ||
|
||||
(ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS && n_rollback > llama_n_rs_seq(ctx_tgt));
|
||||
|
||||
// check for partial draft acceptance
|
||||
if (n_rollback > 0) {
|
||||
if (use_ckpt_tgt) {
|
||||
if (trace > 0) {
|
||||
SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size());
|
||||
}
|
||||
|
||||
// partial acceptance is not supported by the context -> truncate the draft and restore the state
|
||||
slot.spec_draft = std::move(accepted);
|
||||
|
||||
const auto & ckpt = slot.spec_ckpt;
|
||||
|
||||
SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size());
|
||||
|
||||
{
|
||||
ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
common_context_seq_rm(slot.ctx_tgt, slot.id, ckpt.pos_max + 1, -1);
|
||||
}
|
||||
|
||||
if (slot.ctx_dft) {
|
||||
ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
common_context_seq_rm(slot.ctx_dft, slot.id, ckpt.pos_max + 1, -1);
|
||||
}
|
||||
|
||||
slot.prompt.tokens.keep_first(ckpt.n_tokens);
|
||||
slot.smpl = std::move(smpl_save);
|
||||
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (trace > 0) {
|
||||
SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft);
|
||||
}
|
||||
|
||||
common_speculative_accept(spec.get(), slot.id, accepted.size() - 1);
|
||||
|
||||
slot.spec_draft = std::move(accepted);
|
||||
}
|
||||
|
||||
const int64_t t_now = ggml_time_us();
|
||||
|
||||
const auto ids = std::move(slot.spec_draft);
|
||||
|
||||
slot.t_token_generation = std::max<int64_t>(1, t_now - slot.t_start_generation) / 1e3;
|
||||
|
||||
// update how many tokens out of those tested were accepted
|
||||
slot.n_draft_accepted += ids.size() - 1;
|
||||
slot.n_draft_verif_steps += 1;
|
||||
|
||||
if (slot.n_accepted_per_pos.empty()) {
|
||||
slot.n_accepted_per_pos.resize(common_speculative_n_max(¶ms_base.speculative), 0);
|
||||
}
|
||||
for (size_t i = 0; i < ids.size() - 1 && i < slot.n_accepted_per_pos.size(); ++i) {
|
||||
slot.n_accepted_per_pos[i]++;
|
||||
}
|
||||
|
||||
// add accepted tokens to the prompt
|
||||
slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
|
||||
slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
|
||||
|
||||
slot.sampled = ids.back(); // last accepted token
|
||||
SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft);
|
||||
|
||||
common_context_seq_rm(slot.ctx_tgt, slot.id, slot.prompt.tokens.pos_next(), -1);
|
||||
if (slot.ctx_dft) {
|
||||
common_context_seq_rm(slot.ctx_dft, slot.id, slot.prompt.tokens.pos_next(), -1);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < ids.size(); ++i) {
|
||||
completion_token_output result;
|
||||
|
||||
result.tok = ids[i];
|
||||
result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok));
|
||||
result.prob = 1.0f; // set later
|
||||
|
||||
// TODO: set result.probs
|
||||
|
||||
slot.n_decoded += 1;
|
||||
|
||||
if (!process_token(result, slot)) {
|
||||
slot.print_timings();
|
||||
send_final_response(slot);
|
||||
metrics.on_prediction(slot);
|
||||
slot.release();
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
slot.print_timings_tg();
|
||||
|
||||
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) n_draft, slot.prompt.n_tokens());
|
||||
}
|
||||
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) n_draft, slot.prompt.n_tokens());
|
||||
}
|
||||
|
||||
SRV_DBG("%s", "run slots completed\n");
|
||||
}
|
||||
|
||||
int get_slot_n_ctx() {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user