handle batch full more carefully

This commit is contained in:
Xuan Son Nguyen 2026-06-20 19:30:59 +02:00
parent d704c7929b
commit 64ec03d10b

View File

@ -442,11 +442,12 @@ struct server_slot {
// add sampled token of this slot to the batch, optionally add the speculative draft tokens if any
void handle_last_sampled_token(server_batch & batch) {
bool add_ok = true;
if (spec_draft.empty()) {
// no speculative decoding
i_batch = batch.size();
batch.add(id, sampled, prompt.tokens.pos_next(), true);
add_ok &= batch.add(id, sampled, prompt.tokens.pos_next(), true);
SLT_DBG(*this, "slot decode token, id=%d, n_ctx = %d, n_tokens = %d, truncated = %d\n",
sampled, n_ctx, prompt.n_tokens(), truncated);
@ -463,12 +464,14 @@ struct server_slot {
auto pos0 = prompt.tokens.pos_next();
batch.add(id, sampled, pos0++, true);
add_ok &= batch.add(id, sampled, pos0++, true);
for (auto token : spec_draft) {
batch.add(this->id, token, pos0++, true);
add_ok &= batch.add(this->id, token, pos0++, true);
}
}
GGML_ASSERT(add_ok && "batch must be large enough to hold the sampled and draft tokens");
prompt.tokens.push_back(sampled);
prompt.tokens.insert(spec_draft);
}
@ -2679,6 +2682,7 @@ private:
} catch (const std::exception & e) {
SRV_ERR("decode() failed: %s\n", e.what());
abort_all_slots("decode() failed: " + std::string(e.what()));
break; // stop any further processing
}
try {
@ -2686,11 +2690,9 @@ private:
} catch (const std::exception & e) {
SRV_ERR("post_decode() failed: %s\n", e.what());
abort_all_slots("post_decode() failed: " + std::string(e.what()));
break; // stop any further processing
}
if (batch.size() >= n_batch) {
break;
}
}
}
@ -2898,7 +2900,13 @@ private:
// next, batch any pending prompts without exceeding n_batch
if (params_base.cont_batching || batch.size() == 0) {
bool add_ok = true;
iterate(slots, [&](server_slot & slot) {
if (!add_ok || batch.size() >= n_batch) {
return; // batch is full, skip remaining slots
}
if (!slot.is_processing()) {
return;
}
@ -3318,7 +3326,7 @@ private:
// embedding requires all tokens in the batch to be output;
// MTP also wants logits at every prompt position so the
// streaming hook can mirror t_h_nextn into ctx_dft.
batch.add(slot.id,
add_ok &= batch.add(slot.id,
cur_tok,
slot.prompt.tokens.pos_next(),
slot.need_embd());