diff --git a/common/speculative.cpp b/common/speculative.cpp index 7c92e7cb..2341bb6c 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -2212,20 +2212,18 @@ int32_t mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens for seq_id %d from pos %d...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens, seq_id, (int)start_pos); + // We never need all logits. We only need the logits of the last token so we can sample + // the next draft token. In the MTP_OP_WARMUP case we do not need logits at all, but just + // in case we also get the logits of the last token. llama_batch mtp_batch = batch; + for (int i = 0; i < mtp_batch.n_tokens; ++i) { + mtp_batch.logits[i] = false; + } + mtp_batch.logits[mtp_batch.n_tokens-1] = true; if (is_prompt_warmup) { llama_set_mtp_op_type(ctx, MTP_OP_WARMUP); - // We don't need the logits when doing warmup - for (int i = 0; i < mtp_batch.n_tokens; ++i) { - mtp_batch.logits[i] = false; - } - // This is just in case to not run into empty tensor issues - mtp_batch.logits[mtp_batch.n_tokens-1] = true; } else { llama_set_mtp_op_type(ctx, MTP_OP_UPDATE_ACCEPTED); - for (int i = 0; i < mtp_batch.n_tokens; ++i) { - mtp_batch.logits[i] = true; - } } const int32_t ret = llama_decode(ctx, mtp_batch); diff --git a/src/llama.cpp b/src/llama.cpp index 4fa591ea..e65188cf 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4980,7 +4980,7 @@ static int llama_decode_internal( } // reserve output buffer - n_outputs_embd = has_mtp ? n_tokens_all : n_outputs; + n_outputs_embd = has_mtp && cparams.mtp_op_type == MTP_OP_NONE ? n_tokens_all : n_outputs; if (llama_output_reserve(lctx, std::max(n_outputs, n_outputs_embd)) < std::max(n_outputs, n_outputs_embd)) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %zu outputs\n", __func__, std::max(n_outputs, n_outputs_embd)); return -2;