GLM-5 MTP (again) (#1890)

* wip: port MTP architecture

Ports the Multi-Token Prediction (MTP) architecture to the older `llama.cpp` codebase used by `ikllama`.

Changes include:
- Updating `llama_batch` to support `mtp_params`.
- Modifying `llama_decode_internal` (and `encode`) to handle MTP operations (Warmup, Update, Draft).
- Adding public APIs for MTP state management (`llama_set_draft_input_hidden_state`).
- Adapting the embedding extraction logic to skip MTP update passes.

* Refactors `server_slot` to support generic speculative decoding (MTP or Draft Model).

* core: enable hybrid outputs (logits + embeddings) for MTP support

* fix(mtp): correct KV-cache slot finding for updates

* fix(mtp): persist hidden states to prevent context corruption during drafting

* refactor(mtp): clean unused code

* fix(mtp): update server to new functions name

* fix(mtp): fix graph and save hidden state

* mtp: refactor integration, context params and kv cache search

* mtp: fix hidden state extraction and speculative acceptance flow

* server: fix MTP warmup for long prompts and reset token buffer

* llama: refactor MTP operation state to context parameters

* server: fix n_past calculation in MTP acceptance

* llama: fix mtp enable flags

* speculative: refactor MTP to use common_speculative interface

* context: remove unused signatures

* clip: fix deprecated enum-enum conversion warning

* common: fix format string crash in help message

* context: fix mtp activation logic

* llamat: always use the extracted embedding

* llama: get all embeddings to kv cache

* llama: revert logit to not run mtp for not supported arch

* llama: allocate all the n_outputs for MTP

* wip

* server-context: get only the last embedding for hidden state

* ggml-backend: fix array of bounds in debug build

* server-context: run mt kv update to each prompt batch

* revert segmentation fault fixes

* glm-mtp(feat): optimize graph embedding and recursive drafting

* glm5-mtp(feat): add glm 5 mtp logic

* glm-mtp: standardize the MTP graph

* glm 5 mtp: apply post-layer cvec

* glm 5 mtp: mark head as mandatory

* get normed embeddings for glm 5

* Fix GLM5 MTP

* GLM5 MTP: just reuse the layer attention implementation

* Make MTP work with split mode graph

---------

Co-authored-by: samuel <samueloliveira32df@gmail.com>
This commit is contained in:
Kawrakow 2026-05-28 18:14:12 +03:00 committed by GitHub
parent 3bf7e836c2
commit 6eff055a0c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 182 additions and 22 deletions

View File

@ -32,6 +32,9 @@ ggml_tensor * llm_build_context::build_deepseek2_tp_attention(
const uint32_t n_embd_head_v = hparams.n_embd_head_v(il);
auto cache_repl = (const ggml_split_tensor_t *)kv_self.k_l[il]->extra;
if (!cache_repl) {
LLAMA_LOG_ERROR("%s: no cache split for layer %d?\n", __func__, il);
}
GGML_ASSERT(cache_repl);
GGML_ASSERT(cache_repl->n_device == n_device);
@ -799,6 +802,32 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
ggml_rope_cache(ctx0, inp_pos, nullptr, n_rot, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow) : nullptr;
if (cparams.mtp_op_type != MTP_OP_NONE) {
if (model.arch != LLM_ARCH_GLM_DSA || !model.mtp || hparams.nextn_predict_layers == 0) {
GGML_ABORT("MTP tail is only wired for GLM_DSA models with NextN layers enabled");
}
ggml_tensor * hidden_states_from_main_model;
if (cparams.mtp_op_type == MTP_OP_WARMUP || cparams.mtp_op_type == MTP_OP_UPDATE_ACCEPTED) {
hidden_states_from_main_model = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
} else {
hidden_states_from_main_model = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd);
}
ggml_set_name(hidden_states_from_main_model, "inp_mtp_states");
ggml_set_input(hidden_states_from_main_model);
lctx.inp_mtp_states = hidden_states_from_main_model;
const int il_mtp = hparams.n_layer - 1;
const auto & mtp_layer = model.layers[il_mtp];
cur = build_deepseek2_mtp(mtp_layer, hidden_states_from_main_model, gf, inp_pos, rope_cache);
ggml_build_forward_expand(gf, cur);
return gf;
}
int n_active_layers = hparams.n_layer - hparams.nextn_predict_layers;
for (int il = 0; il < n_active_layers; ++il) {
struct ggml_tensor * inpSA = inpL;
@ -815,7 +844,7 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
use_f32_attn_precision, is_lite, pp_opt);
}
if (il == n_active_layers - 1) {
if (il == n_active_layers - 1 && !lctx.cparams.mtp) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
n_tokens = n_outputs;
@ -914,3 +943,114 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
return gf;
}
struct ggml_tensor * llm_build_context::build_deepseek2_mtp(
const llama_layer & mtp_layer,
struct ggml_tensor * prev_embeddings,
struct ggml_cgraph * gf,
struct ggml_tensor * inp_pos,
[[maybe_unused]] struct ggml_tensor * rope_cache) {
#ifdef GGML_USE_VULKAN
constexpr bool use_f32_attn_precision = true;
#else
constexpr bool use_f32_attn_precision = false;
#endif
const int il = hparams.n_layer - 1;
const uint32_t n_embd_head_k_mtp = hparams.n_embd_head_k(il);
const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(n_embd_head_k_mtp));
const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
struct ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr;
// Token embedding
ggml_tensor * mtp_embd_weights = mtp_layer.nextn.embed_tokens;
if (mtp_embd_weights == nullptr) {
mtp_embd_weights = model.tok_embd;
}
ggml_tensor * token_emb = build_inp_embd_mtp(mtp_embd_weights);
// Normalize and project
ggml_tensor * token_emb_norm = llm_build_norm(ctx0, token_emb, hparams, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, cb, il);
ggml_tensor * hidden_state_norm = llm_build_norm(ctx0, prev_embeddings, hparams, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, cb, il);
if (mtp_layer.nextn.eh_proj == nullptr) {
GGML_ABORT("GLM_DSA MTP requires nextn.eh_proj");
}
ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0);
cb(combined, "mtp_concat", il);
ggml_tensor * cur = llm_build_lora_mm(lctx, ctx0, mtp_layer.nextn.eh_proj, combined);
struct ggml_tensor * inpSA = cur;
cur = build_deepseek2_layer_attention(gf, il, cur, KQ_mask, inp_pos, nullptr,
kq_scale, attn_factor_scaled,
use_f32_attn_precision, false, false);
// Residual + FFN
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "mtp_ffn_inp", il);
if (inp_out_ids) {
ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
}
cur = llm_build_norm(ctx0, ffn_inp, hparams, mtp_layer.ffn_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
// MoE FFN (MTP layer is always in the MoE range, not dense)
{
ggml_tensor * moe_out =
llm_build_moe_ffn(ctx0, lctx, cur,
mtp_layer.ffn_gate_inp,
mtp_layer.ffn_up_exps,
mtp_layer.ffn_gate_exps,
mtp_layer.ffn_down_exps,
mtp_layer.ffn_exp_probs_b,
n_expert, n_expert_used,
LLM_FFN_SILU, hparams.expert_weights_norm,
true, hparams.expert_weights_scale,
(enum llm_expert_gating_func_type) hparams.expert_gating_func,
cb, il, gf, false, mtp_layer.ffn_up_gate_exps);
cb(moe_out, "ffn_moe_out", il);
// Shared Expert FFN
ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, nullptr, cur,
mtp_layer.ffn_up_shexp, NULL, NULL,
mtp_layer.ffn_gate_shexp, NULL, NULL,
mtp_layer.ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(ffn_shexp, "ffn_shexp", il);
cur = ggml_add(ctx0, moe_out, ffn_shexp);
cb(cur, "ffn_out", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "mtp_ffn_out_resid", il);
// Output head
if (mtp_layer.nextn.shared_head_norm == nullptr) {
GGML_ABORT("GLM_DSA MTP requires nextn.shared_head_norm");
}
cur = llm_build_norm(ctx0, cur, hparams, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "result_norm", -1);
// If nextn.shared_head_head is missing, use model.output (Main LM Head)
ggml_tensor * mtp_head_weights = mtp_layer.nextn.shared_head_head;
if (mtp_head_weights == nullptr) {
mtp_head_weights = model.output;
}
cur = llm_build_lora_mm(lctx, ctx0, mtp_head_weights, cur);
cb(cur, "result_output", -1);
return cur;
}

View File

@ -486,6 +486,14 @@ llm_expert_gating_func_type gating_op,
struct ggml_tensor * rope_cache
);
struct ggml_tensor * build_deepseek2_mtp(
const struct llama_layer & mtp_layer,
struct ggml_tensor * prev_embeddings,
struct ggml_cgraph * gf,
struct ggml_tensor * inp_pos,
struct ggml_tensor * rope_cache
);
struct ggml_tensor * build_qwen35_mtp(
const struct llama_layer & mtp_layer,
struct ggml_tensor * prev_embeddings,

View File

@ -1379,8 +1379,12 @@ void llm_load_hparams(
// NextN/MTP parameters
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
// TODO: when MTP is implemented, this should probably be updated if needed
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
if (model.mtp) {
hparams.n_layer_kv_from_start = hparams.n_layer;
}
else {
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
}
switch (hparams.n_layer) {
case 79: model.type = MODEL_744B_A40B; break;

View File

@ -2616,8 +2616,11 @@ bool create_tensors_helper::create_glm_dsa_tensors(const LLM_TN & tn) {
static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers;
int flags = 0;
if (is_mtp_layer) {
flags |= llama_model_loader::TENSOR_SKIP | llama_model_loader::TENSOR_NOT_REQUIRED;
// Skip loading MTP layers if the feature is disabled
if (!model.mtp) {
if (is_mtp_layer) {
flags |= llama_model_loader::TENSOR_SKIP | llama_model_loader::TENSOR_NOT_REQUIRED;
}
}
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);
@ -2701,14 +2704,14 @@ bool create_tensors_helper::create_glm_dsa_tensors(const LLM_TN & tn) {
}
if (is_mtp_layer) {
layer.nextn.eh_proj = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags);
layer.nextn.enorm = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags);
layer.nextn.hnorm = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags);
layer.nextn.eh_proj = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags);
layer.nextn.enorm = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags);
layer.nextn.hnorm = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags);
// Optional tensors
layer.nextn.embed_tokens = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | llama_model_loader::TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_head = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | llama_model_loader::TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_norm = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | llama_model_loader::TENSOR_NOT_REQUIRED);
// Optional tensors
layer.nextn.embed_tokens = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | llama_model_loader::TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_head = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | llama_model_loader::TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_norm = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags);
}
}
return use_mmap_buffer;

View File

@ -770,7 +770,7 @@ static bool llama_kv_cache_init(
const struct llama_hparams & hparams = model.hparams;
const int64_t n_layer = model.mtp ? hparams.n_layer
: hparams.n_layer - hparams.nextn_predict_layers;
: hparams.n_layer - hparams.nextn_predict_layers;
cache.has_shift = false;
@ -818,11 +818,12 @@ static bool llama_kv_cache_init(
// count used buffer types
std::map<ggml_backend_buffer_type_t, int> buft_layer_count;
if (offload) {
const bool qwen_mtp = (model.arch == LLM_ARCH_QWEN35 ||
model.arch == LLM_ARCH_QWEN35MOE) && hparams.nextn_predict_layers > 0;
const int64_t n_mtp_first = n_layer - hparams.nextn_predict_layers;
const bool is_mtp = (model.arch == LLM_ARCH_GLM_DSA ||
model.arch == LLM_ARCH_QWEN35 ||
model.arch == LLM_ARCH_QWEN35MOE) && hparams.nextn_predict_layers > 0;
const int64_t n_mtp_first = hparams.n_layer - hparams.nextn_predict_layers;
for (int64_t i = 0; i < n_layer; ++i) {
const bool is_mtp_tail = qwen_mtp && i >= n_mtp_first;
const bool is_mtp_tail = is_mtp && i >= n_mtp_first;
if ((split_cache || replicate_mla) && !is_mtp_tail) {
buft_layer_count[model.buft_layer[i].buft_matrix]++;
if (model.buft_layer[i].buft != model.buft_layer[i].buft_matrix) {
@ -897,7 +898,8 @@ static bool llama_kv_cache_init(
}
int n_mla = 0;
const int64_t n_mtp_first_layer = n_layer - hparams.nextn_predict_layers;
int n_kv_active_layers = 0;
const int64_t n_mtp_first_layer = hparams.n_layer - hparams.nextn_predict_layers;
for (int i = 0; i < (int) n_layer; i++) {
// For MTP-only context, skip KV allocation for non-MTP layers
if (cparams.mtp_op_type != MTP_OP_NONE && i < (int)n_mtp_first_layer) {
@ -907,13 +909,15 @@ static bool llama_kv_cache_init(
}
continue;
}
n_kv_active_layers++;
const bool qnext_recurrent = llama_is_recurrent_layer(hparams, i);
const uint32_t n_embd_v_row = llama_kv_v_row_embd(model, hparams, i);
const uint32_t n_head_kv = hparams.n_head_kv(i);
const uint32_t n_embd_head_k= hparams.n_embd_head_k(i);
const bool is_mtp_tail_layer = (model.arch == LLM_ARCH_QWEN35 ||
model.arch == LLM_ARCH_QWEN35MOE) &&
model.arch == LLM_ARCH_QWEN35MOE ||
model.arch == LLM_ARCH_GLM_DSA) &&
hparams.nextn_predict_layers > 0 && i >= (int)n_mtp_first_layer;
//struct ggml_context * ctx = split_cache && !qnext_recurrent ? ctx_map.at(model.buft_layer[i].buft_matrix) : offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
struct ggml_context * ctx = ((split_cache || replicate_mla) && !is_mtp_tail_layer) ? ctx_map.at(model.buft_layer[i].buft_matrix) : offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
@ -1083,8 +1087,8 @@ static bool llama_kv_cache_init(
cache.v_l.push_back(v);
}
}
if (is_mla_attn && cparams.mla_attn && n_mla < n_layer && n_mla > 0) {
LLAMA_LOG_ERROR("%s: unexpected situation with %d out of %d layers having MLA enabled\n", __func__, n_mla, int(n_layer));
if (is_mla_attn && cparams.mla_attn && n_mla < n_kv_active_layers && n_mla > 0) {
LLAMA_LOG_ERROR("%s: unexpected situation with %d out of %d active KV layers having MLA enabled\n", __func__, n_mla, n_kv_active_layers);
LLAMA_LOG_ERROR("%s: bailing out\n", __func__);
GGML_ABORT("fatal error");
}
@ -6820,7 +6824,8 @@ struct llama_context * llama_init_from_model(
if (model->arch != LLM_ARCH_GLM4_MOE && model->arch != LLM_ARCH_QWEN35 &&
model->arch != LLM_ARCH_QWEN35MOE && model->arch != LLM_ARCH_GEMMA4 &&
model->arch != LLM_ARCH_GEMMA4_MTP && cparams.mtp != 0) {
model->arch != LLM_ARCH_GEMMA4_MTP && model->arch != LLM_ARCH_GLM_DSA &&
cparams.mtp != 0) {
cparams.mtp = 0;
}