MTP tweaks 3

This commit is contained in:
Kawrakow 2026-05-22 09:37:36 +00:00
parent b3d39cff8b
commit 8bf4e6ca50
2 changed files with 8 additions and 10 deletions

View File

@@ -2212,20 +2212,18 @@ int32_t mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch
LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens for seq_id %d from pos %d...\n",
is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens, seq_id, (int)start_pos);
// We never need all logits. We only need the logits of the last token so we can sample
// the next draft token. In the MTP_OP_WARMUP case we do not need logits at all, but just
// in case we also get the logits of the last token.
llama_batch mtp_batch = batch;
for (int i = 0; i < mtp_batch.n_tokens; ++i) {
mtp_batch.logits[i] = false;
}
mtp_batch.logits[mtp_batch.n_tokens-1] = true;
if (is_prompt_warmup) {
llama_set_mtp_op_type(ctx, MTP_OP_WARMUP);
// We don't need the logits when doing warmup
for (int i = 0; i < mtp_batch.n_tokens; ++i) {
mtp_batch.logits[i] = false;
}
// This is just in case to not run into empty tensor issues
mtp_batch.logits[mtp_batch.n_tokens-1] = true;
} else {
llama_set_mtp_op_type(ctx, MTP_OP_UPDATE_ACCEPTED);
for (int i = 0; i < mtp_batch.n_tokens; ++i) {
mtp_batch.logits[i] = true;
}
}
const int32_t ret = llama_decode(ctx, mtp_batch);

View File

@@ -4980,7 +4980,7 @@ static int llama_decode_internal(
}
// reserve output buffer
n_outputs_embd = has_mtp ? n_tokens_all : n_outputs;
n_outputs_embd = has_mtp && cparams.mtp_op_type == MTP_OP_NONE ? n_tokens_all : n_outputs;
if (llama_output_reserve(lctx, std::max<size_t>(n_outputs, n_outputs_embd)) < std::max<size_t>(n_outputs, n_outputs_embd)) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %zu outputs\n", __func__, std::max<size_t>(n_outputs, n_outputs_embd));
return -2;