speculative: keep MTP draft hidden state alive across steps (#1718)

This commit is contained in:
dmaivel 2026-05-02 06:05:41 -07:00 committed by GitHub
parent a8aecbf159
commit 1b14f56693
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1374,6 +1374,8 @@ std::vector<llama_token> mtp_speculative_gen_draft(
llama_token current_input_id = id_last;
int32_t current_n_past = n_past;
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
std::vector<float> draft_hidden_state(n_embd);
for (int i = 0; i < n_draft; ++i) {
mtp_batch.n_tokens = 0;
@ -1389,9 +1391,13 @@ std::vector<llama_token> mtp_speculative_gen_draft(
drafts.push_back(id_next);
const float * emb = llama_get_embeddings_ith(ctx, 0);
if (emb) {
llama_set_draft_input_hidden_state(ctx, emb);
if (!emb) {
break;
}
// Keep a stable copy because later decode steps reuse ctx->embd storage.
memcpy(draft_hidden_state.data(), emb, n_embd * sizeof(float));
llama_set_draft_input_hidden_state(ctx, draft_hidden_state.data());
current_input_id = id_next;
current_n_past++;