From 9f60de9cc5ff446ce5d5e7a6d2dac2c5a5d503f3 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sat, 9 May 2026 08:31:25 +0300 Subject: [PATCH] Fix discarding tokens from the KV cache during MTP drafting (#1757) --- common/speculative.cpp | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 2e142b25..d2685067 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -1410,17 +1410,19 @@ std::vector mtp_speculative_gen_draft( i0 = 1; } + int n_decode = 0; for (int i = i0; i < n_draft; ++i) { mtp_batch.n_tokens = 0; common_batch_add(mtp_batch, current_input_id, current_n_past, {seq_id}, true); + ++n_decode; if (llama_decode(ctx, mtp_batch) != 0) { break; } llama_token id_next = common_sampler_sample_speculative(smpl, ctx, 0, prob_ptr); if (i > 0 && prob_ptr && prob < p_min) { - return drafts; + break; } drafts.push_back(id_next); @@ -1446,8 +1448,15 @@ std::vector mtp_speculative_gen_draft( // Purge the metadata for the draft tokens. // This prevents cache state corruption where two cells map to the same logical position. - if (!drafts.empty()) { - llama_kv_cache_seq_rm(ctx, seq_id, n_past, current_n_past); + // If the state contained in `last` had a valid token id and probability, it means that we + // have previously run an "accept" batch, where the token sampled from the main model was included. + // In that case, we need to discard all tokens that we ran here to get the KV cache to the correct state. + // => for i0 = 1 we discard from n_past + // But if we did not have a valid last token_id, it means the first token we run was sampled from the + // main model. Hence we want to keep this token in the KV cache and discard all other tokens. + // => for i0 = 0 we discard from n_past + 1 + if (n_decode > 0) { + llama_kv_cache_seq_rm(ctx, seq_id, n_past + 1 - i0, n_past + n_decode + 2); } return drafts;