mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Fix discarding tokens from the KV cache during MTP drafting (#1757)
This commit is contained in:
parent
98950267c6
commit
9f60de9cc5
@ -1410,17 +1410,19 @@ std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
i0 = 1;
|
||||
}
|
||||
|
||||
int n_decode = 0;
|
||||
for (int i = i0; i < n_draft; ++i) {
|
||||
mtp_batch.n_tokens = 0;
|
||||
common_batch_add(mtp_batch, current_input_id, current_n_past, {seq_id}, true);
|
||||
|
||||
++n_decode;
|
||||
if (llama_decode(ctx, mtp_batch) != 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
llama_token id_next = common_sampler_sample_speculative(smpl, ctx, 0, prob_ptr);
|
||||
if (i > 0 && prob_ptr && prob < p_min) {
|
||||
return drafts;
|
||||
break;
|
||||
}
|
||||
|
||||
drafts.push_back(id_next);
|
||||
@ -1446,8 +1448,15 @@ std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
|
||||
// Purge the metadata for the draft tokens.
|
||||
// This prevents cache state corruption where two cells map to the same logical position.
|
||||
if (!drafts.empty()) {
|
||||
llama_kv_cache_seq_rm(ctx, seq_id, n_past, current_n_past);
|
||||
// If the state contained in `last` had a valid token id and probability, it means that we
|
||||
// have previously run an "accept" batch, where the token sampled from the main model was included.
|
||||
// In that case, we need to discard all tokens that we ran here to get the KV cache to the correct state.
|
||||
// => for i0 = 1 we discard from n_past
|
||||
// But if we did not have a valid last token_id, it means the first token we run was sampled from the
|
||||
// main model. Hence we want to keep this token in the KV cache and discard all other tokens.
|
||||
// => for i0 = 0 we discard from n_past + 1
|
||||
if (n_decode > 0) {
|
||||
llama_kv_cache_seq_rm(ctx, seq_id, n_past + 1 - i0, n_past + n_decode + 2);
|
||||
}
|
||||
|
||||
return drafts;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user