Add llama_context to MTP (#1601)

* wip: separate llama_context for MTP with graph reuse

* wip: fix KV cache desync with separate MTP context

* refactor: remove dead mtp logic code, encapsulate KV mirroring

* mtp-context: derive args directly from the main model's context

* mtp: fix kv cache positions

* clean small comments

* minor refactor for context shift
This commit is contained in:
Samuel Oliveira Alves 2026-04-09 10:33:56 -03:00 committed by GitHub
parent 9b5785ad6b
commit 557b674f63
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 113 additions and 67 deletions

View File

@ -148,11 +148,13 @@ struct common_speculative_state {
struct common_speculative_state_mtp : public common_speculative_state {
llama_context * ctx_tgt;
llama_context * ctx_mtp = nullptr;
common_sampler * smpl;
common_speculative_state_mtp(
enum common_speculative_type type,
llama_context * ctx_tgt)
llama_context * ctx_tgt,
const llama_context_params & mtp_cparams)
: common_speculative_state(type)
, ctx_tgt(ctx_tgt)
{
@ -161,10 +163,21 @@ struct common_speculative_state_mtp : public common_speculative_state {
llama_sampler_type::DIST,
};
smpl = common_sampler_init(llama_get_model(ctx_tgt), params);
const llama_model * model = llama_get_model(ctx_tgt);
ctx_mtp = llama_init_from_model(const_cast<llama_model *>(model), mtp_cparams);
if (ctx_mtp) {
LOG_INF("%s: created MTP context (n_ctx=%d)\n", __func__, llama_n_ctx(ctx_mtp));
} else {
LOG_ERR("%s: failed to create MTP context, falling back to shared context\n", __func__);
}
}
// Tear-down: the sampler is always released; the dedicated MTP context is
// released only if one was successfully created.
~common_speculative_state_mtp() override {
    common_sampler_free(smpl);

    if (ctx_mtp == nullptr) {
        return;
    }
    llama_free(ctx_mtp);
}
void begin(const llama_tokens & prompt) override {
@ -178,12 +191,20 @@ struct common_speculative_state_mtp : public common_speculative_state {
llama_tokens & result) override {
int32_t n_past = (int32_t)prompt_tgt.size();
llama_seq_id seq_id = 0;
if (ctx_mtp) {
llama_pos mtp_pos_max = llama_kv_cache_seq_pos_max(ctx_mtp, seq_id);
if (mtp_pos_max >= n_past) {
llama_kv_cache_seq_rm(ctx_mtp, seq_id, n_past, -1);
}
}
llama_context * ctx = ctx_mtp ? ctx_mtp : ctx_tgt;
result = mtp_speculative_gen_draft(
smpl,
ctx_tgt,
ctx,
params.n_max,
params.p_min,
id_last,
@ -954,7 +975,8 @@ common_speculative * common_speculative_init(
}
case COMMON_SPECULATIVE_TYPE_MTP: {
impls.push_back(std::make_unique<common_speculative_state_mtp>(config.type,
/* .ctx_tgt = */ ctx_tgt
/* .ctx_tgt = */ ctx_tgt,
/* .mtp_cparams = */ params.cparams_dft
));
break;
}
@ -1166,6 +1188,33 @@ void common_speculative_print_stats(const common_speculative * spec, double slot
// ----------------------------------------------------------------------------
// MTP
// ----------------------------------------------------------------------------
// Looks up the MTP context held by a speculative object.
//
// Walks the registered implementations and, for the first one that is tagged
// COMMON_SPECULATIVE_TYPE_MTP and really is a common_speculative_state_mtp,
// returns its ctx_mtp member (which may itself be nullptr).
// Returns nullptr when `spec` is null or no such implementation is present.
llama_context * common_speculative_get_mtp_ctx(common_speculative * spec) {
    if (spec == nullptr) {
        return nullptr;
    }

    for (auto & state : spec->impls) {
        if (state->type != COMMON_SPECULATIVE_TYPE_MTP) {
            continue;
        }
        // The type tag is double-checked with a dynamic_cast before touching the member.
        if (auto * mtp = dynamic_cast<common_speculative_state_mtp *>(state.get())) {
            return mtp->ctx_mtp;
        }
    }

    return nullptr;
}
// Applies a context shift to the MTP context's KV cache for one sequence.
//
// Positions from kv_keep up to kv_keep + kv_discard are removed, then the
// entries from kv_keep + kv_discard up to kv_past are moved back by kv_discard.
// Nothing happens when `spec` provides no MTP context.
void common_speculative_context_shift(
        common_speculative * spec,
        llama_seq_id seq_id,
        llama_pos kv_keep,
        llama_pos kv_discard,
        llama_pos kv_past) {
    llama_context * mtp = common_speculative_get_mtp_ctx(spec);
    if (mtp == nullptr) {
        return;
    }

    // First position that survives the discard; everything from here on is shifted.
    const llama_pos shift_begin = kv_keep + kv_discard;

    llama_kv_cache_seq_rm (mtp, seq_id, kv_keep, shift_begin);
    llama_kv_cache_seq_add(mtp, seq_id, shift_begin, kv_past, -kv_discard);
}
std::vector<llama_token> mtp_speculative_gen_draft(
struct common_sampler * smpl,
struct llama_context * ctx,
@ -1231,7 +1280,15 @@ void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, b
return;
}
LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens);
llama_seq_id seq_id = batch.seq_id[0][0];
llama_pos start_pos = batch.pos[0];
if (llama_kv_cache_seq_pos_max(ctx, seq_id) >= start_pos) {
llama_kv_cache_seq_rm(ctx, seq_id, start_pos, -1);
}
LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens from pos %d...\n",
is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens, (int)start_pos);
llama_batch mtp_batch = batch;
if (is_prompt_warmup) {

View File

@ -41,6 +41,17 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
// print statistics about the speculative decoding
void common_speculative_print_stats(const common_speculative * spec, double slot_tps = 0.0, int n_decoded = 0, int n_past = 0, common_params_speculative * active_params = nullptr);
// Get the MTP context from the speculative object.
// Returns nullptr when `spec` is null, when it holds no MTP-type implementation,
// or when that implementation has no MTP context of its own.
llama_context * common_speculative_get_mtp_ctx(common_speculative * spec);
// Context shift for the MTP context, matching how the server handles the main model.
// Does nothing when `spec` has no MTP context.
//   seq_id     - sequence whose KV cache entries are shifted
//   kv_keep    - first position of the range that is discarded
//   kv_discard - number of positions discarded; later entries move back by this amount
//   kv_past    - end of the range of positions that are moved back
void common_speculative_context_shift(
common_speculative * spec,
llama_seq_id seq_id,
llama_pos kv_keep,
llama_pos kv_discard,
llama_pos kv_past);
// Generates speculative draft tokens using the Multi-Token Prediction (MTP) architecture.
std::vector<llama_token> mtp_speculative_gen_draft(
struct common_sampler * smpl,

View File

@ -214,9 +214,15 @@ void server_context::init() {
params_base.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
params_base.pooling_type = LLAMA_POOLING_TYPE_NONE;
params_base.speculative.cparams_dft = common_context_params_to_llama(params_base);
params_base.speculative.cparams_dft.mtp = true;
params_base.speculative.cparams_dft.mtp_op_type = MTP_OP_WARMUP;
params_base.speculative.cparams_dft.embeddings = true;
slot.has_mtp = true;
slot.params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
slot.params.speculative.n_min = 0;
slot.params.speculative.cparams_dft = params_base.speculative.cparams_dft;
slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens);
@ -2622,6 +2628,9 @@ void server_context::discard_n_kv_and_cache_tokens(llama_context* ctx, server_sl
const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
llama_kv_cache_seq_rm(ctx, slot.id, kv_keep, kv_keep + kv_discard);
llama_kv_cache_seq_add(ctx, slot.id, kv_keep + kv_discard, kv_past, -kv_discard);
if (slot.has_mtp && slot.spec) {
common_speculative_context_shift(slot.spec, slot.id, kv_keep, kv_discard, kv_past);
}
if (slot.params.cache_prompt) {
slot.cache_tokens.discard_n_tokens(n_keep, n_discard);
}
@ -2838,10 +2847,12 @@ void server_context::add_sampled_tokens() {
auto & params_spec = slot.params.speculative;
if (slot.has_mtp) {
llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec);
llama_context * hs_ctx = mtp_ctx ? mtp_ctx : ctx;
if (!slot.mtp_hidden_state.empty()) {
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
const int n_hidden = slot.mtp_hidden_state.size() / n_embd;
llama_set_draft_input_hidden_state(ctx, slot.mtp_hidden_state.data() + (n_hidden - 1) * n_embd);
llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data() + (n_hidden - 1) * n_embd);
} else {
LOG_ERROR("MTP hidden state is empty during speculation", {});
const float* emb_neg1 = llama_get_embeddings_ith(ctx, -1);
@ -2849,7 +2860,7 @@ void server_context::add_sampled_tokens() {
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
slot.mtp_hidden_state.resize(n_embd);
memcpy(slot.mtp_hidden_state.data(), emb_neg1, n_embd * sizeof(float));
llama_set_draft_input_hidden_state(ctx, slot.mtp_hidden_state.data());
llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data());
}
}
}
@ -3415,6 +3426,9 @@ void server_context::speculative_decoding_accept() {
const auto ids = common_sampler_sample_and_accept_n(slot.ctx_sampling, ctx, slot.i_batch_dft, slot.drafted);
if (slot.has_mtp) {
llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec);
llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx;
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
if (!ids.empty()) {
const float* emb = llama_get_embeddings(ctx);
@ -3430,10 +3444,10 @@ void server_context::speculative_decoding_accept() {
}
}
llama_set_draft_input_hidden_state(ctx, slot.mtp_hidden_state.data());
llama_set_draft_input_hidden_state(mtp_target, slot.mtp_hidden_state.data());
int32_t n_past_base = slot.n_past - (slot.drafted.size() + 1);
mtp_accept_tokens(ctx, ids, n_past_base, slot.id);
mtp_accept_tokens(mtp_target, ids, n_past_base, slot.id);
}
slot.i_batch_dft.clear();
@ -3933,8 +3947,16 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
slot.i_batch = -1;
}
if (mtp_warmup_needed && !batch_mtp_hidden_state.empty()) {
llama_set_draft_input_hidden_state(ctx, batch_mtp_hidden_state.data());
mtp_update_kv_cache(ctx, batch_view, true);
llama_context * mtp_ctx = nullptr;
for (auto & slot : slots) {
if (slot.spec && slot.has_mtp) {
llama_context * mc = common_speculative_get_mtp_ctx(slot.spec);
if (mc) { mtp_ctx = mc; break; }
}
}
llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx;
llama_set_draft_input_hidden_state(mtp_target, batch_mtp_hidden_state.data());
mtp_update_kv_cache(mtp_target, batch_view, true);
}
// speculative decoding - main model sample and accept

View File

@ -47,9 +47,6 @@ struct llama_kv_cache {
uint32_t size = 0;
uint32_t used = 0; // used cells (i.e. at least one seq_id)
// Track's main model's head position for MTP KV cache operations
uint32_t mtp_kv_head_hint = 0;
// computed before each graph build
uint32_t n = 0;

View File

@ -850,7 +850,16 @@ static bool llama_kv_cache_init(
}
int n_mla = 0;
const int64_t n_mtp_first_layer = n_layer - hparams.nextn_predict_layers;
for (int i = 0; i < (int) n_layer; i++) {
// For MTP-only context, skip KV allocation for non-MTP layers
if (cparams.mtp_op_type != MTP_OP_NONE && i < (int)n_mtp_first_layer) {
cache.k_l.push_back(nullptr);
if (!is_mla_attn || !cparams.mla_attn || (cparams.mla_attn == 1 && !cparams.flash_attn)) {
cache.v_l.push_back(nullptr);
}
continue;
}
const bool qnext_recurrent = llama_is_recurrent_layer(hparams, i);
const uint32_t n_embd_v_row = llama_kv_v_row_embd(model, hparams, i);
const uint32_t n_head_kv = hparams.n_head_kv(i);
@ -1066,8 +1075,7 @@ static bool llama_kv_cache_init(
// to the first cell of the slot.
static bool llama_kv_cache_find_slot(
struct llama_kv_cache & cache,
const struct llama_batch & batch,
enum llama_mtp_op_type op_type) {
const struct llama_batch & batch) {
const uint32_t n_tokens = batch.n_tokens;
if (cache.recurrent) {
@ -1118,51 +1126,6 @@ static bool llama_kv_cache_find_slot(
}
// otherwise, one cell per token.
bool is_mtp_special_op = (op_type == MTP_OP_WARMUP ||
op_type == MTP_OP_UPDATE_ACCEPTED);
if (is_mtp_special_op) {
const llama_pos target_pos = batch.pos[0];
const llama_seq_id target_seq = batch.seq_id[0][0];
bool found = false;
if (cache.mtp_kv_head_hint < cache.size &&
cache.cells[cache.mtp_kv_head_hint].pos == target_pos &&
cache.cells[cache.mtp_kv_head_hint].has_seq_id(target_seq)) {
cache.head = cache.mtp_kv_head_hint;
found = true;
}
else if (cache.head < cache.size &&
cache.cells[cache.head].pos == target_pos &&
cache.cells[cache.head].has_seq_id(target_seq)) {
found = true;
}
else {
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].pos == target_pos &&
cache.cells[i].has_seq_id(target_seq)) {
cache.head = i;
found = true;
break;
}
}
}
if (!found) {
LLAMA_LOG_ERROR("%s: MTP Update failed - slot for seq %d pos %d not found\n",
__func__, target_seq, target_pos);
return false;
}
if (cache.head + n_tokens > cache.size) {
LLAMA_LOG_ERROR("%s: MTP Update out of bounds\n", __func__);
return false;
}
return true;
}
if (n_tokens > cache.size) {
LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
return false;
@ -3922,14 +3885,10 @@ static int llama_decode_internal(
kv_self.head = 0;
}
if (!llama_kv_cache_find_slot(kv_self, u_batch, cparams.mtp_op_type)) {
if (!llama_kv_cache_find_slot(kv_self, u_batch)) {
return 1;
}
if (cparams.mtp_op_type == MTP_OP_NONE) {
kv_self.mtp_kv_head_hint = kv_self.head;
}
if (!kv_self.recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
@ -6842,7 +6801,7 @@ struct llama_data_read {
batch.n_seq_id[i] = 1;
batch.seq_id[i][0] = dest_seq_id;
}
if (!llama_kv_cache_find_slot(kv_self, batch, ctx->cparams.mtp_op_type)) {
if (!llama_kv_cache_find_slot(kv_self, batch)) {
llama_batch_free(batch);
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
return false;