Fix discarding tokens from the KV cache during MTP drafting

This commit is contained in:
Kawrakow 2026-05-08 04:51:59 +00:00
parent 9a26522af2
commit d0c4dd6c55

View File

@ -1410,17 +1410,19 @@ std::vector<llama_token> mtp_speculative_gen_draft(
i0 = 1;
}
int n_decode = 0;
for (int i = i0; i < n_draft; ++i) {
mtp_batch.n_tokens = 0;
common_batch_add(mtp_batch, current_input_id, current_n_past, {seq_id}, true);
++n_decode;
if (llama_decode(ctx, mtp_batch) != 0) {
break;
}
llama_token id_next = common_sampler_sample_speculative(smpl, ctx, 0, prob_ptr);
if (i > 0 && prob_ptr && prob < p_min) {
return drafts;
break;
}
drafts.push_back(id_next);
@ -1446,8 +1448,15 @@ std::vector<llama_token> mtp_speculative_gen_draft(
// Purge the metadata for the draft tokens.
// This prevents cache state corruption where two cells map to the same logical position.
if (!drafts.empty()) {
llama_kv_cache_seq_rm(ctx, seq_id, n_past, current_n_past);
// If the state contained in `last` had a valid token id and probability, it means that we
// have previously run an "accept" batch, in which the token sampled from the main model was included.
// In that case, we need to discard all tokens that we ran here to get the KV cache to the correct state.
// => for i0 = 1 we discard from n_past
// But if we did not have a valid last token id, it means that the first token we ran here was sampled
// from the main model. Hence we want to keep this token in the KV cache and discard all other tokens.
// => for i0 = 0 we discard from n_past + 1
if (n_decode > 0) {
llama_kv_cache_seq_rm(ctx, seq_id, n_past + 1 - i0, n_past + n_decode + 2);
}
return drafts;