Minor MTP improvement

This commit is contained in:
Kawrakow 2026-05-04 05:52:04 +00:00
parent 38c200373f
commit b2c9fd1524

View File

@ -1361,7 +1361,7 @@ std::vector<llama_token> mtp_speculative_gen_draft(
llama_token id_last,
int32_t n_past,
llama_seq_id seq_id) {
llama_tokens drafts;
drafts.reserve(n_draft);
@ -1377,6 +1377,9 @@ std::vector<llama_token> mtp_speculative_gen_draft(
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
std::vector<float> draft_hidden_state(n_embd);
float prob;
auto prob_ptr = p_min > 0 ? &prob : nullptr;
for (int i = 0; i < n_draft; ++i) {
mtp_batch.n_tokens = 0;
common_batch_add(mtp_batch, current_input_id, current_n_past, {seq_id}, true);
@ -1385,16 +1388,19 @@ std::vector<llama_token> mtp_speculative_gen_draft(
break;
}
float prob;
llama_token id_next = common_sampler_sample_speculative(smpl, ctx, 0, &prob);
llama_token id_next = common_sampler_sample_speculative(smpl, ctx, 0, prob_ptr);
drafts.push_back(id_next);
if (i > 0 && prob_ptr && prob < p_min) { // i.e., generate at least one draft even if prob < p_min
break;
}
const float * emb = llama_get_embeddings_ith(ctx, 0);
if (!emb) {
break;
}
drafts.push_back(id_next);
// Keep a stable copy because later decode steps reuse ctx->embd storage.
memcpy(draft_hidden_state.data(), emb, n_embd * sizeof(float));
llama_set_draft_input_hidden_state(ctx, draft_hidden_state.data());
@ -1402,7 +1408,7 @@ std::vector<llama_token> mtp_speculative_gen_draft(
current_input_id = id_next;
current_n_past++;
if (prob < p_min) {
if (prob_ptr && prob < p_min) {
break;
}
}