mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Minor MTP improvement
This commit is contained in:
parent
38c200373f
commit
b2c9fd1524
@ -1361,7 +1361,7 @@ std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
llama_token id_last,
|
||||
int32_t n_past,
|
||||
llama_seq_id seq_id) {
|
||||
|
||||
|
||||
llama_tokens drafts;
|
||||
drafts.reserve(n_draft);
|
||||
|
||||
@ -1377,6 +1377,9 @@ std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
||||
std::vector<float> draft_hidden_state(n_embd);
|
||||
|
||||
float prob;
|
||||
auto prob_ptr = p_min > 0 ? &prob : nullptr;
|
||||
|
||||
for (int i = 0; i < n_draft; ++i) {
|
||||
mtp_batch.n_tokens = 0;
|
||||
common_batch_add(mtp_batch, current_input_id, current_n_past, {seq_id}, true);
|
||||
@ -1385,16 +1388,19 @@ std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
break;
|
||||
}
|
||||
|
||||
float prob;
|
||||
llama_token id_next = common_sampler_sample_speculative(smpl, ctx, 0, &prob);
|
||||
llama_token id_next = common_sampler_sample_speculative(smpl, ctx, 0, prob_ptr);
|
||||
|
||||
drafts.push_back(id_next);
|
||||
if (i > 0 && prob_ptr && prob < p_min) { // i.e., generate at least one draft even if prob < p_min
|
||||
break;
|
||||
}
|
||||
|
||||
const float * emb = llama_get_embeddings_ith(ctx, 0);
|
||||
if (!emb) {
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
drafts.push_back(id_next);
|
||||
|
||||
// Keep a stable copy because later decode steps reuse ctx->embd storage.
|
||||
memcpy(draft_hidden_state.data(), emb, n_embd * sizeof(float));
|
||||
llama_set_draft_input_hidden_state(ctx, draft_hidden_state.data());
|
||||
@ -1402,7 +1408,7 @@ std::vector<llama_token> mtp_speculative_gen_draft(
|
||||
current_input_id = id_next;
|
||||
current_n_past++;
|
||||
|
||||
if (prob < p_min) {
|
||||
if (prob_ptr && prob < p_min) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user