From 8b56d813a9ed04fa7b7fe2588fddd845cf64eccb Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 5 May 2026 08:05:24 +0300 Subject: [PATCH] MTP improvements (#1736) * MTP improvements * Cleanup --- common/speculative.cpp | 70 +++++++++++++++++++++++++++++-------- src/graphs/build_qwen35.cpp | 6 ++-- src/llama.cpp | 6 ++-- 3 files changed, 62 insertions(+), 20 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index c7c8d86b..1eee7cee 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 @@ -219,7 +220,6 @@ struct common_speculative_state_mtp : public common_speculative_state { } }; - struct common_speculative_state_draft : public common_speculative_state { llama_context * ctx_tgt; // only used for retokenizing from ctx_dft llama_context * ctx_dft; @@ -1213,6 +1213,23 @@ void common_speculative_begin(common_speculative * spec, const llama_tokens & pr } } +struct mtp_last_embd { + std::vector embd; + float prob; + int last_id = -1; +}; + +// Hopefully never called concurrently from multiple threads +static mtp_last_embd & mtp_get_last_embd(const llama_context * ctx) { + static std::unordered_map map; + auto & last = map[ctx]; + if (last.embd.empty()) { + auto n_embd = llama_model_n_embd(llama_get_model(ctx)); + last.embd.resize(n_embd); + } + return last; +} + llama_tokens common_speculative_draft( common_speculative * spec, common_params_speculative & params, @@ -1361,7 +1378,7 @@ std::vector mtp_speculative_gen_draft( llama_token id_last, int32_t n_past, llama_seq_id seq_id) { - + llama_tokens drafts; drafts.reserve(n_draft); @@ -1372,12 +1389,28 @@ std::vector mtp_speculative_gen_draft( llama_batch mtp_batch = llama_batch_init(1, 0, 1); llama_set_mtp_op_type(ctx, MTP_OP_DRAFT_GEN); + float prob; + auto prob_ptr = p_min > 0 ? &prob : nullptr; + llama_token current_input_id = id_last; int32_t current_n_past = n_past; const int n_embd = llama_model_n_embd(llama_get_model(ctx)); - std::vector draft_hidden_state(n_embd); - for (int i = 0; i < n_draft; ++i) { + auto & last = mtp_get_last_embd(ctx); + int i0 = 0; + if (last.last_id >= 0) { + if (last.prob < p_min) { + return drafts; + } + current_input_id = last.last_id; + last.last_id = -1; + drafts.push_back(current_input_id); + current_n_past++; + llama_set_draft_input_hidden_state(ctx, last.embd.data()); + i0 = 1; + } + + for (int i = i0; i < n_draft; ++i) { mtp_batch.n_tokens = 0; common_batch_add(mtp_batch, current_input_id, current_n_past, {seq_id}, true); @@ -1385,8 +1418,10 @@ std::vector mtp_speculative_gen_draft( break; } - float prob; - llama_token id_next = common_sampler_sample_speculative(smpl, ctx, 0, &prob); + llama_token id_next = common_sampler_sample_speculative(smpl, ctx, 0, prob_ptr); + if (i > 0 && prob_ptr && prob < p_min) { + return drafts; + } drafts.push_back(id_next); @@ -1394,15 +1429,15 @@ std::vector mtp_speculative_gen_draft( if (!emb) { break; } - + // Keep a stable copy because later decode steps reuse ctx->embd storage. - memcpy(draft_hidden_state.data(), emb, n_embd * sizeof(float)); - llama_set_draft_input_hidden_state(ctx, draft_hidden_state.data()); + memcpy(last.embd.data(), emb, n_embd * sizeof(float)); + llama_set_draft_input_hidden_state(ctx, last.embd.data()); current_input_id = id_next; current_n_past++; - if (prob < p_min) { + if (prob_ptr && prob < p_min) { break; } } @@ -1431,8 +1466,8 @@ void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, b llama_kv_cache_seq_rm(ctx, seq_id, start_pos, -1); } - LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens from pos %d...\n", - is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens, (int)start_pos); + LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens for seq_id %d from pos %d...\n", + is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens, seq_id, (int)start_pos); llama_batch mtp_batch = batch; if (is_prompt_warmup) { @@ -1452,8 +1487,7 @@ void mtp_accept_tokens( struct llama_context * ctx, const std::vector & ids, int32_t n_past_base, - llama_seq_id seq_id -) { + llama_seq_id seq_id) { if (ids.empty()) { return; } @@ -1465,5 +1499,13 @@ void mtp_accept_tokens( mtp_update_kv_cache(ctx, accepted_batch, false); + auto & last = mtp_get_last_embd(ctx); + auto embd = llama_get_embeddings_ith(ctx, ids.size() - 1); + if (embd) { + std::memcpy(last.embd.data(), embd, last.embd.size()*sizeof(float)); + llama_set_draft_input_hidden_state(ctx, last.embd.data()); + last.last_id = common_sampler_sample_speculative(nullptr, ctx, ids.size() - 1, &last.prob); + } + llama_batch_free(accepted_batch); } diff --git a/src/graphs/build_qwen35.cpp b/src/graphs/build_qwen35.cpp index fb19d679..2f129177 100644 --- a/src/graphs/build_qwen35.cpp +++ b/src/graphs/build_qwen35.cpp @@ -147,8 +147,8 @@ struct ggml_tensor * llm_build_context::build_qwen35_mtp( struct ggml_tensor * prev_embeddings, int64_t n_embd_head, struct ggml_cgraph * gf, - struct ggml_tensor * inp_pos -) { + struct ggml_tensor * inp_pos) { + const int il = hparams.n_layer - 1; struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); @@ -217,4 +217,4 @@ struct ggml_tensor * llm_build_context::build_qwen35_mtp( cb(cur, "result_output", -1); return cur; -} \ No newline at end of file +} diff --git a/src/llama.cpp b/src/llama.cpp index 28121cf7..8edd580c 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4590,8 +4590,7 @@ static int llama_decode_internal( tim1 = ggml_time_us(); #endif // Do not process logits if MTP is only updating the KV cache. - if (cparams.mtp_op_type != MTP_OP_WARMUP && - cparams.mtp_op_type != MTP_OP_UPDATE_ACCEPTED) { + if (cparams.mtp_op_type != MTP_OP_WARMUP) { // && cparams.mtp_op_type != MTP_OP_UPDATE_ACCEPTED) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res); GGML_ASSERT(backend_res != nullptr); GGML_ASSERT(lctx.logits != nullptr); @@ -4627,7 +4626,8 @@ static int llama_decode_internal( } // extract embeddings - if (embd && (cparams.mtp_op_type == MTP_OP_NONE || cparams.mtp_op_type == MTP_OP_DRAFT_GEN)) { + //if (embd && (cparams.mtp_op_type == MTP_OP_NONE || cparams.mtp_op_type == MTP_OP_DRAFT_GEN)) { + if (embd && cparams.mtp_op_type != MTP_OP_WARMUP) { #if IK_PRINT_TIMING tim1 = ggml_time_us(); #endif