MTP improvements (#1736)

* MTP improvements

* Cleanup
This commit is contained in:
Kawrakow 2026-05-05 08:05:24 +03:00 committed by GitHub
parent 45dfd80371
commit 8b56d813a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 62 additions and 20 deletions

View File

@ -14,6 +14,7 @@
#include <cstring>
#include <iomanip>
#include <map>
#include <unordered_map>
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
@ -219,7 +220,6 @@ struct common_speculative_state_mtp : public common_speculative_state {
}
};
struct common_speculative_state_draft : public common_speculative_state {
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
llama_context * ctx_dft;
@ -1213,6 +1213,23 @@ void common_speculative_begin(common_speculative * spec, const llama_tokens & pr
}
}
// Per-context MTP state carried from one MTP step to the next.
struct mtp_last_embd {
    // Hidden-state buffer handed to llama_set_draft_input_hidden_state();
    // sized to the model's n_embd by mtp_get_last_embd().
    std::vector<float> embd;
    // Probability reported by the sampler for `last_id`; compared against p_min
    // before the pending token is used. Only meaningful while last_id >= 0.
    float prob = 0.0f;
    // Token sampled ahead of the next draft call, or -1 when none is pending.
    int last_id = -1;
};
// Returns the cached MTP state for `ctx`, creating it on first use.
//
// The returned reference stays valid for the lifetime of the program: the map
// is a function-local static and entries are never erased.
//
// Not synchronized: the map is shared by every caller, so this must not be
// called concurrently from multiple threads.
static mtp_last_embd & mtp_get_last_embd(const llama_context * ctx) {
    static std::unordered_map<const llama_context *, mtp_last_embd> map;
    auto & last = map[ctx];
    const auto n_embd = static_cast<size_t>(llama_model_n_embd(llama_get_model(ctx)));
    if (last.embd.size() != n_embd) {
        // Either a brand-new entry, or the address of a previous context has been
        // reused by one whose model has a different embedding width. Callers copy
        // n_embd floats into this buffer, so it must match the current model; any
        // pending token belongs to the old state and is dropped.
        last.embd.assign(n_embd, 0.0f);
        last.last_id = -1;
    }
    return last;
}
llama_tokens common_speculative_draft(
common_speculative * spec,
common_params_speculative & params,
@ -1361,7 +1378,7 @@ std::vector<llama_token> mtp_speculative_gen_draft(
llama_token id_last,
int32_t n_past,
llama_seq_id seq_id) {
llama_tokens drafts;
drafts.reserve(n_draft);
@ -1372,12 +1389,28 @@ std::vector<llama_token> mtp_speculative_gen_draft(
llama_batch mtp_batch = llama_batch_init(1, 0, 1);
llama_set_mtp_op_type(ctx, MTP_OP_DRAFT_GEN);
float prob;
auto prob_ptr = p_min > 0 ? &prob : nullptr;
llama_token current_input_id = id_last;
int32_t current_n_past = n_past;
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
std::vector<float> draft_hidden_state(n_embd);
for (int i = 0; i < n_draft; ++i) {
auto & last = mtp_get_last_embd(ctx);
int i0 = 0;
if (last.last_id >= 0) {
if (last.prob < p_min) {
return drafts;
}
current_input_id = last.last_id;
last.last_id = -1;
drafts.push_back(current_input_id);
current_n_past++;
llama_set_draft_input_hidden_state(ctx, last.embd.data());
i0 = 1;
}
for (int i = i0; i < n_draft; ++i) {
mtp_batch.n_tokens = 0;
common_batch_add(mtp_batch, current_input_id, current_n_past, {seq_id}, true);
@ -1385,8 +1418,10 @@ std::vector<llama_token> mtp_speculative_gen_draft(
break;
}
float prob;
llama_token id_next = common_sampler_sample_speculative(smpl, ctx, 0, &prob);
llama_token id_next = common_sampler_sample_speculative(smpl, ctx, 0, prob_ptr);
if (i > 0 && prob_ptr && prob < p_min) {
return drafts;
}
drafts.push_back(id_next);
@ -1394,15 +1429,15 @@ std::vector<llama_token> mtp_speculative_gen_draft(
if (!emb) {
break;
}
// Keep a stable copy because later decode steps reuse ctx->embd storage.
memcpy(draft_hidden_state.data(), emb, n_embd * sizeof(float));
llama_set_draft_input_hidden_state(ctx, draft_hidden_state.data());
memcpy(last.embd.data(), emb, n_embd * sizeof(float));
llama_set_draft_input_hidden_state(ctx, last.embd.data());
current_input_id = id_next;
current_n_past++;
if (prob < p_min) {
if (prob_ptr && prob < p_min) {
break;
}
}
@ -1431,8 +1466,8 @@ void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, b
llama_kv_cache_seq_rm(ctx, seq_id, start_pos, -1);
}
LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens from pos %d...\n",
is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens, (int)start_pos);
LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens for seq_id %d from pos %d...\n",
is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens, seq_id, (int)start_pos);
llama_batch mtp_batch = batch;
if (is_prompt_warmup) {
@ -1452,8 +1487,7 @@ void mtp_accept_tokens(
struct llama_context * ctx,
const std::vector<llama_token> & ids,
int32_t n_past_base,
llama_seq_id seq_id
) {
llama_seq_id seq_id) {
if (ids.empty()) {
return;
}
@ -1465,5 +1499,13 @@ void mtp_accept_tokens(
mtp_update_kv_cache(ctx, accepted_batch, false);
auto & last = mtp_get_last_embd(ctx);
auto embd = llama_get_embeddings_ith(ctx, ids.size() - 1);
if (embd) {
std::memcpy(last.embd.data(), embd, last.embd.size()*sizeof(float));
llama_set_draft_input_hidden_state(ctx, last.embd.data());
last.last_id = common_sampler_sample_speculative(nullptr, ctx, ids.size() - 1, &last.prob);
}
llama_batch_free(accepted_batch);
}

View File

@ -147,8 +147,8 @@ struct ggml_tensor * llm_build_context::build_qwen35_mtp(
struct ggml_tensor * prev_embeddings,
int64_t n_embd_head,
struct ggml_cgraph * gf,
struct ggml_tensor * inp_pos
) {
struct ggml_tensor * inp_pos) {
const int il = hparams.n_layer - 1;
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
@ -217,4 +217,4 @@ struct ggml_tensor * llm_build_context::build_qwen35_mtp(
cb(cur, "result_output", -1);
return cur;
}
}

View File

@ -4590,8 +4590,7 @@ static int llama_decode_internal(
tim1 = ggml_time_us();
#endif
// Do not process logits if MTP is only updating the KV cache.
if (cparams.mtp_op_type != MTP_OP_WARMUP &&
cparams.mtp_op_type != MTP_OP_UPDATE_ACCEPTED) {
if (cparams.mtp_op_type != MTP_OP_WARMUP) { // && cparams.mtp_op_type != MTP_OP_UPDATE_ACCEPTED) {
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
GGML_ASSERT(backend_res != nullptr);
GGML_ASSERT(lctx.logits != nullptr);
@ -4627,7 +4626,8 @@ static int llama_decode_internal(
}
// extract embeddings
if (embd && (cparams.mtp_op_type == MTP_OP_NONE || cparams.mtp_op_type == MTP_OP_DRAFT_GEN)) {
//if (embd && (cparams.mtp_op_type == MTP_OP_NONE || cparams.mtp_op_type == MTP_OP_DRAFT_GEN)) {
if (embd && cparams.mtp_op_type != MTP_OP_WARMUP) {
#if IK_PRINT_TIMING
tim1 = ggml_time_us();
#endif