mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
speculative: keep MTP draft hidden state alive across steps (#1718)
This commit is contained in:
parent
a8aecbf159
commit
1b14f56693
@ -1374,6 +1374,8 @@ std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
|
||||
llama_token current_input_id = id_last;
|
||||
int32_t current_n_past = n_past;
|
||||
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
||||
std::vector<float> draft_hidden_state(n_embd);
|
||||
|
||||
for (int i = 0; i < n_draft; ++i) {
|
||||
mtp_batch.n_tokens = 0;
|
||||
@ -1389,9 +1391,13 @@ std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
drafts.push_back(id_next);
|
||||
|
||||
const float * emb = llama_get_embeddings_ith(ctx, 0);
|
||||
if (emb) {
|
||||
llama_set_draft_input_hidden_state(ctx, emb);
|
||||
if (!emb) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Keep a stable copy because later decode steps reuse ctx->embd storage.
|
||||
memcpy(draft_hidden_state.data(), emb, n_embd * sizeof(float));
|
||||
llama_set_draft_input_hidden_state(ctx, draft_hidden_state.data());
|
||||
|
||||
current_input_id = id_next;
|
||||
current_n_past++;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user