From 557b674f63a617efa7aa0b075d9cf40134b0eeac Mon Sep 17 00:00:00 2001 From: Samuel Oliveira Alves <107287165+SamuelOliveirads@users.noreply.github.com> Date: Thu, 9 Apr 2026 10:33:56 -0300 Subject: [PATCH] Add llama_context to MTP (#1601) * wip: separate llama_context for MTP with graph reuse * wip: fix KV cache desync with separate MTP context * refactor: remove dead mtp logic code, encapsulate KV mirroring * mtp-context: derive args directly from the main model's context * mtp: fix kv cache positions * clean small comments * minor refactor for context shift --- common/speculative.cpp | 67 +++++++++++++++++++++++++++--- common/speculative.h | 11 +++++ examples/server/server-context.cpp | 34 ++++++++++++--- src/llama-context.h | 3 -- src/llama.cpp | 65 ++++++----------------------- 5 files changed, 113 insertions(+), 67 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 4599ff49..d32db0fb 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -148,11 +148,13 @@ struct common_speculative_state { struct common_speculative_state_mtp : public common_speculative_state { llama_context * ctx_tgt; + llama_context * ctx_mtp = nullptr; common_sampler * smpl; common_speculative_state_mtp( enum common_speculative_type type, - llama_context * ctx_tgt) + llama_context * ctx_tgt, + const llama_context_params & mtp_cparams) : common_speculative_state(type) , ctx_tgt(ctx_tgt) { @@ -161,10 +163,21 @@ struct common_speculative_state_mtp : public common_speculative_state { llama_sampler_type::DIST, }; smpl = common_sampler_init(llama_get_model(ctx_tgt), params); + + const llama_model * model = llama_get_model(ctx_tgt); + ctx_mtp = llama_init_from_model(const_cast(model), mtp_cparams); + if (ctx_mtp) { + LOG_INF("%s: created MTP context (n_ctx=%d)\n", __func__, llama_n_ctx(ctx_mtp)); + } else { + LOG_ERR("%s: failed to create MTP context, falling back to shared context\n", __func__); + } } ~common_speculative_state_mtp() override { common_sampler_free(smpl); + if (ctx_mtp) { + llama_free(ctx_mtp); + } } void begin(const llama_tokens & prompt) override { @@ -178,12 +191,20 @@ struct common_speculative_state_mtp : public common_speculative_state { llama_tokens & result) override { int32_t n_past = (int32_t)prompt_tgt.size(); - llama_seq_id seq_id = 0; + if (ctx_mtp) { + llama_pos mtp_pos_max = llama_kv_cache_seq_pos_max(ctx_mtp, seq_id); + if (mtp_pos_max >= n_past) { + llama_kv_cache_seq_rm(ctx_mtp, seq_id, n_past, -1); + } + } + + llama_context * ctx = ctx_mtp ? ctx_mtp : ctx_tgt; + result = mtp_speculative_gen_draft( smpl, - ctx_tgt, + ctx, params.n_max, params.p_min, id_last, @@ -954,7 +975,8 @@ common_speculative * common_speculative_init( } case COMMON_SPECULATIVE_TYPE_MTP: { impls.push_back(std::make_unique(config.type, - /* .ctx_tgt = */ ctx_tgt + /* .ctx_tgt = */ ctx_tgt, + /* .mtp_cparams = */ params.cparams_dft )); break; } @@ -1166,6 +1188,33 @@ void common_speculative_print_stats(const common_speculative * spec, double slot // ---------------------------------------------------------------------------- // MTP // ---------------------------------------------------------------------------- + +llama_context * common_speculative_get_mtp_ctx(common_speculative * spec) { + if (!spec) return nullptr; + + for (auto & impl : spec->impls) { + if (impl->type == COMMON_SPECULATIVE_TYPE_MTP) { + auto * mtp_state = dynamic_cast(impl.get()); + if (mtp_state) { + return mtp_state->ctx_mtp; + } + } + } + return nullptr; +} + +void common_speculative_context_shift( + common_speculative * spec, + llama_seq_id seq_id, + llama_pos kv_keep, + llama_pos kv_discard, + llama_pos kv_past) { + if (auto * ctx_mtp = common_speculative_get_mtp_ctx(spec); ctx_mtp != nullptr) { + llama_kv_cache_seq_rm (ctx_mtp, seq_id, kv_keep, kv_keep + kv_discard); + llama_kv_cache_seq_add(ctx_mtp, seq_id, kv_keep + kv_discard, kv_past, -kv_discard); + } +} + std::vector mtp_speculative_gen_draft( struct common_sampler * smpl, struct llama_context * ctx, @@ -1231,7 +1280,15 @@ void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, b return; } - LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens); + llama_seq_id seq_id = batch.seq_id[0][0]; + llama_pos start_pos = batch.pos[0]; + + if (llama_kv_cache_seq_pos_max(ctx, seq_id) >= start_pos) { + llama_kv_cache_seq_rm(ctx, seq_id, start_pos, -1); + } + + LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens from pos %d...\n", + is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens, (int)start_pos); llama_batch mtp_batch = batch; if (is_prompt_warmup) { diff --git a/common/speculative.h b/common/speculative.h index 45882e30..2b3fc0c0 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -41,6 +41,17 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted); // print statistics about the speculative decoding void common_speculative_print_stats(const common_speculative * spec, double slot_tps = 0.0, int n_decoded = 0, int n_past = 0, common_params_speculative * active_params = nullptr); +// get the MTP context from the speculative object (nullptr if not MTP type) +llama_context * common_speculative_get_mtp_ctx(common_speculative * spec); + +// Context shift for MTP to match how server handle main model +void common_speculative_context_shift( + common_speculative * spec, + llama_seq_id seq_id, + llama_pos kv_keep, + llama_pos kv_discard, + llama_pos kv_past); + // Generates speculative draft tokens using the Multi-Token Prediction (MTP) architecture. std::vector mtp_speculative_gen_draft( struct common_sampler * smpl, diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index 892f532d..dcad17ef 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -214,9 +214,15 @@ void server_context::init() { params_base.speculative.type = COMMON_SPECULATIVE_TYPE_MTP; params_base.pooling_type = LLAMA_POOLING_TYPE_NONE; + params_base.speculative.cparams_dft = common_context_params_to_llama(params_base); + params_base.speculative.cparams_dft.mtp = true; + params_base.speculative.cparams_dft.mtp_op_type = MTP_OP_WARMUP; + params_base.speculative.cparams_dft.embeddings = true; + slot.has_mtp = true; slot.params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP; slot.params.speculative.n_min = 0; + slot.params.speculative.cparams_dft = params_base.speculative.cparams_dft; slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens); @@ -2622,6 +2628,9 @@ void server_context::discard_n_kv_and_cache_tokens(llama_context* ctx, server_sl const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id); llama_kv_cache_seq_rm(ctx, slot.id, kv_keep, kv_keep + kv_discard); llama_kv_cache_seq_add(ctx, slot.id, kv_keep + kv_discard, kv_past, -kv_discard); + if (slot.has_mtp && slot.spec) { + common_speculative_context_shift(slot.spec, slot.id, kv_keep, kv_discard, kv_past); + } if (slot.params.cache_prompt) { slot.cache_tokens.discard_n_tokens(n_keep, n_discard); } @@ -2838,10 +2847,12 @@ void server_context::add_sampled_tokens() { auto & params_spec = slot.params.speculative; if (slot.has_mtp) { + llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec); + llama_context * hs_ctx = mtp_ctx ? mtp_ctx : ctx; if (!slot.mtp_hidden_state.empty()) { const int n_embd = llama_model_n_embd(llama_get_model(ctx)); const int n_hidden = slot.mtp_hidden_state.size() / n_embd; - llama_set_draft_input_hidden_state(ctx, slot.mtp_hidden_state.data() + (n_hidden - 1) * n_embd); + llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data() + (n_hidden - 1) * n_embd); } else { LOG_ERROR("MTP hidden state is empty during speculation", {}); const float* emb_neg1 = llama_get_embeddings_ith(ctx, -1); @@ -2849,7 +2860,7 @@ void server_context::add_sampled_tokens() { const int n_embd = llama_model_n_embd(llama_get_model(ctx)); slot.mtp_hidden_state.resize(n_embd); memcpy(slot.mtp_hidden_state.data(), emb_neg1, n_embd * sizeof(float)); - llama_set_draft_input_hidden_state(ctx, slot.mtp_hidden_state.data()); + llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data()); } } } @@ -3415,6 +3426,9 @@ void server_context::speculative_decoding_accept() { const auto ids = common_sampler_sample_and_accept_n(slot.ctx_sampling, ctx, slot.i_batch_dft, slot.drafted); if (slot.has_mtp) { + llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec); + llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx; + const int n_embd = llama_model_n_embd(llama_get_model(ctx)); if (!ids.empty()) { const float* emb = llama_get_embeddings(ctx); @@ -3430,10 +3444,10 @@ void server_context::speculative_decoding_accept() { } } - llama_set_draft_input_hidden_state(ctx, slot.mtp_hidden_state.data()); + llama_set_draft_input_hidden_state(mtp_target, slot.mtp_hidden_state.data()); int32_t n_past_base = slot.n_past - (slot.drafted.size() + 1); - mtp_accept_tokens(ctx, ids, n_past_base, slot.id); + mtp_accept_tokens(mtp_target, ids, n_past_base, slot.id); } slot.i_batch_dft.clear(); @@ -3933,8 +3947,16 @@ void server_context::process_batch_tokens(int32_t & n_batch) { slot.i_batch = -1; } if (mtp_warmup_needed && !batch_mtp_hidden_state.empty()) { - llama_set_draft_input_hidden_state(ctx, batch_mtp_hidden_state.data()); - mtp_update_kv_cache(ctx, batch_view, true); + llama_context * mtp_ctx = nullptr; + for (auto & slot : slots) { + if (slot.spec && slot.has_mtp) { + llama_context * mc = common_speculative_get_mtp_ctx(slot.spec); + if (mc) { mtp_ctx = mc; break; } + } + } + llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx; + llama_set_draft_input_hidden_state(mtp_target, batch_mtp_hidden_state.data()); + mtp_update_kv_cache(mtp_target, batch_view, true); } // speculative decoding - main model sample and accept diff --git a/src/llama-context.h b/src/llama-context.h index 9cdf8df0..be563122 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -47,9 +47,6 @@ struct llama_kv_cache { uint32_t size = 0; uint32_t used = 0; // used cells (i.e. at least one seq_id) - // Track's main model's head position for MTP KV cache operations - uint32_t mtp_kv_head_hint = 0; - // computed before each graph build uint32_t n = 0; diff --git a/src/llama.cpp b/src/llama.cpp index ba05fa5f..2c3e19eb 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -850,7 +850,16 @@ static bool llama_kv_cache_init( } int n_mla = 0; + const int64_t n_mtp_first_layer = n_layer - hparams.nextn_predict_layers; for (int i = 0; i < (int) n_layer; i++) { + // For MTP-only context, skip KV allocation for non-MTP layers + if (cparams.mtp_op_type != MTP_OP_NONE && i < (int)n_mtp_first_layer) { + cache.k_l.push_back(nullptr); + if (!is_mla_attn || !cparams.mla_attn || (cparams.mla_attn == 1 && !cparams.flash_attn)) { + cache.v_l.push_back(nullptr); + } + continue; + } const bool qnext_recurrent = llama_is_recurrent_layer(hparams, i); const uint32_t n_embd_v_row = llama_kv_v_row_embd(model, hparams, i); const uint32_t n_head_kv = hparams.n_head_kv(i); @@ -1066,8 +1075,7 @@ static bool llama_kv_cache_init( // to the first cell of the slot. static bool llama_kv_cache_find_slot( struct llama_kv_cache & cache, - const struct llama_batch & batch, - enum llama_mtp_op_type op_type) { + const struct llama_batch & batch) { const uint32_t n_tokens = batch.n_tokens; if (cache.recurrent) { @@ -1118,51 +1126,6 @@ static bool llama_kv_cache_find_slot( } // otherwise, one cell per token. - bool is_mtp_special_op = (op_type == MTP_OP_WARMUP || - op_type == MTP_OP_UPDATE_ACCEPTED); - if (is_mtp_special_op) { - const llama_pos target_pos = batch.pos[0]; - const llama_seq_id target_seq = batch.seq_id[0][0]; - - bool found = false; - - if (cache.mtp_kv_head_hint < cache.size && - cache.cells[cache.mtp_kv_head_hint].pos == target_pos && - cache.cells[cache.mtp_kv_head_hint].has_seq_id(target_seq)) { - cache.head = cache.mtp_kv_head_hint; - found = true; - } - else if (cache.head < cache.size && - cache.cells[cache.head].pos == target_pos && - cache.cells[cache.head].has_seq_id(target_seq)) { - found = true; - } - else { - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].pos == target_pos && - cache.cells[i].has_seq_id(target_seq)) { - - cache.head = i; - found = true; - break; - } - } - } - - if (!found) { - LLAMA_LOG_ERROR("%s: MTP Update failed - slot for seq %d pos %d not found\n", - __func__, target_seq, target_pos); - return false; - } - - if (cache.head + n_tokens > cache.size) { - LLAMA_LOG_ERROR("%s: MTP Update out of bounds\n", __func__); - return false; - } - - return true; - } - if (n_tokens > cache.size) { LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size); return false; @@ -3922,14 +3885,10 @@ static int llama_decode_internal( kv_self.head = 0; } - if (!llama_kv_cache_find_slot(kv_self, u_batch, cparams.mtp_op_type)) { + if (!llama_kv_cache_find_slot(kv_self, u_batch)) { return 1; } - if (cparams.mtp_op_type == MTP_OP_NONE) { - kv_self.mtp_kv_head_hint = kv_self.head; - } - if (!kv_self.recurrent) { // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears @@ -6842,7 +6801,7 @@ struct llama_data_read { batch.n_seq_id[i] = 1; batch.seq_id[i][0] = dest_seq_id; } - if (!llama_kv_cache_find_slot(kv_self, batch, ctx->cparams.mtp_op_type)) { + if (!llama_kv_cache_find_slot(kv_self, batch)) { llama_batch_free(batch); LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return false;