mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
GLM-5 MTP (again) (#1890)
* wip: port MTP architecture Ports the Multi-Token Prediction (MTP) architecture to the older `llama.cpp` codebase used by `ikllama`. Changes include: - Updating `llama_batch` to support `mtp_params`. - Modifying `llama_decode_internal` (and `encode`) to handle MTP operations (Warmup, Update, Draft). - Adding public APIs for MTP state management (`llama_set_draft_input_hidden_state`). - Adapting the embedding extraction logic to skip MTP update passes. * Refactors `server_slot` to support generic speculative decoding (MTP or Draft Model). * core: enable hybrid outputs (logits + embeddings) for MTP support * fix(mtp): correct KV-cache slot finding for updates * fix(mtp): persist hidden states to prevent context corruption during drafting * refactor(mtp): clean unused code * fix(mtp): update server to new functions name * fix(mtp): fix graph and save hidden state * mtp: refactor integration, context params and kv cache search * mtp: fix hidden state extraction and speculative acceptance flow * server: fix MTP warmup for long prompts and reset token buffer * llama: refactor MTP operation state to context parameters * server: fix n_past calculation in MTP acceptance * llama: fix mtp enable flags * speculative: refactor MTP to use common_speculative interface * context: remove unused signatures * clip: fix deprecated enum-enum conversion warning * common: fix format string crash in help message * context: fix mtp activation logic * llamat: always use the extracted embedding * llama: get all embeddings to kv cache * llama: revert logit to not run mtp for not supported arch * llama: allocate all the n_outputs for MTP * wip * server-context: get only the last embedding for hidden state * ggml-backend: fix array of bounds in debug build * server-context: run mt kv update to each prompt batch * revert segmentation fault fixes * glm-mtp(feat): optimize graph embedding and recursive drafting * glm5-mtp(feat): add glm 5 mtp logic * glm-mtp: standardize the MTP graph * glm 5 mtp: apply post-layer cvec * glm 5 mtp: mark head as mandatory * get normed embeddings for glm 5 * Fix GLM5 MTP * GLM5 MTP: just reuse the layer attention implementation * Make MTP work with split mode graph --------- Co-authored-by: samuel <samueloliveira32df@gmail.com>
This commit is contained in:
parent
3bf7e836c2
commit
6eff055a0c
@ -32,6 +32,9 @@ ggml_tensor * llm_build_context::build_deepseek2_tp_attention(
|
||||
const uint32_t n_embd_head_v = hparams.n_embd_head_v(il);
|
||||
|
||||
auto cache_repl = (const ggml_split_tensor_t *)kv_self.k_l[il]->extra;
|
||||
if (!cache_repl) {
|
||||
LLAMA_LOG_ERROR("%s: no cache split for layer %d?\n", __func__, il);
|
||||
}
|
||||
GGML_ASSERT(cache_repl);
|
||||
GGML_ASSERT(cache_repl->n_device == n_device);
|
||||
|
||||
@ -799,6 +802,32 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
|
||||
ggml_rope_cache(ctx0, inp_pos, nullptr, n_rot, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow) : nullptr;
|
||||
|
||||
if (cparams.mtp_op_type != MTP_OP_NONE) {
|
||||
if (model.arch != LLM_ARCH_GLM_DSA || !model.mtp || hparams.nextn_predict_layers == 0) {
|
||||
GGML_ABORT("MTP tail is only wired for GLM_DSA models with NextN layers enabled");
|
||||
}
|
||||
|
||||
ggml_tensor * hidden_states_from_main_model;
|
||||
|
||||
if (cparams.mtp_op_type == MTP_OP_WARMUP || cparams.mtp_op_type == MTP_OP_UPDATE_ACCEPTED) {
|
||||
hidden_states_from_main_model = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
|
||||
} else {
|
||||
hidden_states_from_main_model = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd);
|
||||
}
|
||||
ggml_set_name(hidden_states_from_main_model, "inp_mtp_states");
|
||||
ggml_set_input(hidden_states_from_main_model);
|
||||
|
||||
lctx.inp_mtp_states = hidden_states_from_main_model;
|
||||
|
||||
const int il_mtp = hparams.n_layer - 1;
|
||||
const auto & mtp_layer = model.layers[il_mtp];
|
||||
|
||||
cur = build_deepseek2_mtp(mtp_layer, hidden_states_from_main_model, gf, inp_pos, rope_cache);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
return gf;
|
||||
}
|
||||
|
||||
int n_active_layers = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
for (int il = 0; il < n_active_layers; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
@ -815,7 +844,7 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
|
||||
use_f32_attn_precision, is_lite, pp_opt);
|
||||
}
|
||||
|
||||
if (il == n_active_layers - 1) {
|
||||
if (il == n_active_layers - 1 && !lctx.cparams.mtp) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
n_tokens = n_outputs;
|
||||
@ -914,3 +943,114 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
struct ggml_tensor * llm_build_context::build_deepseek2_mtp(
|
||||
const llama_layer & mtp_layer,
|
||||
struct ggml_tensor * prev_embeddings,
|
||||
struct ggml_cgraph * gf,
|
||||
struct ggml_tensor * inp_pos,
|
||||
[[maybe_unused]] struct ggml_tensor * rope_cache) {
|
||||
#ifdef GGML_USE_VULKAN
|
||||
constexpr bool use_f32_attn_precision = true;
|
||||
#else
|
||||
constexpr bool use_f32_attn_precision = false;
|
||||
#endif
|
||||
|
||||
const int il = hparams.n_layer - 1;
|
||||
|
||||
const uint32_t n_embd_head_k_mtp = hparams.n_embd_head_k(il);
|
||||
|
||||
const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
|
||||
const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(n_embd_head_k_mtp));
|
||||
const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
|
||||
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
struct ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr;
|
||||
|
||||
// Token embedding
|
||||
ggml_tensor * mtp_embd_weights = mtp_layer.nextn.embed_tokens;
|
||||
if (mtp_embd_weights == nullptr) {
|
||||
mtp_embd_weights = model.tok_embd;
|
||||
}
|
||||
ggml_tensor * token_emb = build_inp_embd_mtp(mtp_embd_weights);
|
||||
|
||||
// Normalize and project
|
||||
ggml_tensor * token_emb_norm = llm_build_norm(ctx0, token_emb, hparams, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, cb, il);
|
||||
ggml_tensor * hidden_state_norm = llm_build_norm(ctx0, prev_embeddings, hparams, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, cb, il);
|
||||
|
||||
if (mtp_layer.nextn.eh_proj == nullptr) {
|
||||
GGML_ABORT("GLM_DSA MTP requires nextn.eh_proj");
|
||||
}
|
||||
|
||||
ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0);
|
||||
cb(combined, "mtp_concat", il);
|
||||
ggml_tensor * cur = llm_build_lora_mm(lctx, ctx0, mtp_layer.nextn.eh_proj, combined);
|
||||
|
||||
struct ggml_tensor * inpSA = cur;
|
||||
|
||||
cur = build_deepseek2_layer_attention(gf, il, cur, KQ_mask, inp_pos, nullptr,
|
||||
kq_scale, attn_factor_scaled,
|
||||
use_f32_attn_precision, false, false);
|
||||
|
||||
// Residual + FFN
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "mtp_ffn_inp", il);
|
||||
|
||||
if (inp_out_ids) {
|
||||
ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
|
||||
}
|
||||
|
||||
cur = llm_build_norm(ctx0, ffn_inp, hparams, mtp_layer.ffn_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
// MoE FFN (MTP layer is always in the MoE range, not dense)
|
||||
{
|
||||
ggml_tensor * moe_out =
|
||||
llm_build_moe_ffn(ctx0, lctx, cur,
|
||||
mtp_layer.ffn_gate_inp,
|
||||
mtp_layer.ffn_up_exps,
|
||||
mtp_layer.ffn_gate_exps,
|
||||
mtp_layer.ffn_down_exps,
|
||||
mtp_layer.ffn_exp_probs_b,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, hparams.expert_weights_norm,
|
||||
true, hparams.expert_weights_scale,
|
||||
(enum llm_expert_gating_func_type) hparams.expert_gating_func,
|
||||
cb, il, gf, false, mtp_layer.ffn_up_gate_exps);
|
||||
cb(moe_out, "ffn_moe_out", il);
|
||||
|
||||
// Shared Expert FFN
|
||||
ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, nullptr, cur,
|
||||
mtp_layer.ffn_up_shexp, NULL, NULL,
|
||||
mtp_layer.ffn_gate_shexp, NULL, NULL,
|
||||
mtp_layer.ffn_down_shexp, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||
cb(ffn_shexp, "ffn_shexp", il);
|
||||
|
||||
cur = ggml_add(ctx0, moe_out, ffn_shexp);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
||||
cb(cur, "mtp_ffn_out_resid", il);
|
||||
|
||||
// Output head
|
||||
if (mtp_layer.nextn.shared_head_norm == nullptr) {
|
||||
GGML_ABORT("GLM_DSA MTP requires nextn.shared_head_norm");
|
||||
}
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
// If nextn.shared_head_head is missing, use model.output (Main LM Head)
|
||||
ggml_tensor * mtp_head_weights = mtp_layer.nextn.shared_head_head;
|
||||
if (mtp_head_weights == nullptr) {
|
||||
mtp_head_weights = model.output;
|
||||
}
|
||||
cur = llm_build_lora_mm(lctx, ctx0, mtp_head_weights, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
@ -486,6 +486,14 @@ llm_expert_gating_func_type gating_op,
|
||||
struct ggml_tensor * rope_cache
|
||||
);
|
||||
|
||||
struct ggml_tensor * build_deepseek2_mtp(
|
||||
const struct llama_layer & mtp_layer,
|
||||
struct ggml_tensor * prev_embeddings,
|
||||
struct ggml_cgraph * gf,
|
||||
struct ggml_tensor * inp_pos,
|
||||
struct ggml_tensor * rope_cache
|
||||
);
|
||||
|
||||
struct ggml_tensor * build_qwen35_mtp(
|
||||
const struct llama_layer & mtp_layer,
|
||||
struct ggml_tensor * prev_embeddings,
|
||||
|
||||
@ -1379,8 +1379,12 @@ void llm_load_hparams(
|
||||
// NextN/MTP parameters
|
||||
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
|
||||
|
||||
// TODO: when MTP is implemented, this should probably be updated if needed
|
||||
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
if (model.mtp) {
|
||||
hparams.n_layer_kv_from_start = hparams.n_layer;
|
||||
}
|
||||
else {
|
||||
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
}
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 79: model.type = MODEL_744B_A40B; break;
|
||||
|
||||
@ -2616,8 +2616,11 @@ bool create_tensors_helper::create_glm_dsa_tensors(const LLM_TN & tn) {
|
||||
static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers;
|
||||
|
||||
int flags = 0;
|
||||
if (is_mtp_layer) {
|
||||
flags |= llama_model_loader::TENSOR_SKIP | llama_model_loader::TENSOR_NOT_REQUIRED;
|
||||
// Skip loading MTP layers if the feature is disabled
|
||||
if (!model.mtp) {
|
||||
if (is_mtp_layer) {
|
||||
flags |= llama_model_loader::TENSOR_SKIP | llama_model_loader::TENSOR_NOT_REQUIRED;
|
||||
}
|
||||
}
|
||||
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||
ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||
@ -2701,14 +2704,14 @@ bool create_tensors_helper::create_glm_dsa_tensors(const LLM_TN & tn) {
|
||||
}
|
||||
|
||||
if (is_mtp_layer) {
|
||||
layer.nextn.eh_proj = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags);
|
||||
layer.nextn.enorm = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags);
|
||||
layer.nextn.hnorm = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags);
|
||||
layer.nextn.eh_proj = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags);
|
||||
layer.nextn.enorm = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags);
|
||||
layer.nextn.hnorm = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags);
|
||||
|
||||
// Optional tensors
|
||||
layer.nextn.embed_tokens = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.nextn.shared_head_head = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.nextn.shared_head_norm = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
// Optional tensors
|
||||
layer.nextn.embed_tokens = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.nextn.shared_head_head = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.nextn.shared_head_norm = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags);
|
||||
}
|
||||
}
|
||||
return use_mmap_buffer;
|
||||
|
||||
@ -770,7 +770,7 @@ static bool llama_kv_cache_init(
|
||||
const struct llama_hparams & hparams = model.hparams;
|
||||
|
||||
const int64_t n_layer = model.mtp ? hparams.n_layer
|
||||
: hparams.n_layer - hparams.nextn_predict_layers;
|
||||
: hparams.n_layer - hparams.nextn_predict_layers;
|
||||
|
||||
cache.has_shift = false;
|
||||
|
||||
@ -818,11 +818,12 @@ static bool llama_kv_cache_init(
|
||||
// count used buffer types
|
||||
std::map<ggml_backend_buffer_type_t, int> buft_layer_count;
|
||||
if (offload) {
|
||||
const bool qwen_mtp = (model.arch == LLM_ARCH_QWEN35 ||
|
||||
model.arch == LLM_ARCH_QWEN35MOE) && hparams.nextn_predict_layers > 0;
|
||||
const int64_t n_mtp_first = n_layer - hparams.nextn_predict_layers;
|
||||
const bool is_mtp = (model.arch == LLM_ARCH_GLM_DSA ||
|
||||
model.arch == LLM_ARCH_QWEN35 ||
|
||||
model.arch == LLM_ARCH_QWEN35MOE) && hparams.nextn_predict_layers > 0;
|
||||
const int64_t n_mtp_first = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
for (int64_t i = 0; i < n_layer; ++i) {
|
||||
const bool is_mtp_tail = qwen_mtp && i >= n_mtp_first;
|
||||
const bool is_mtp_tail = is_mtp && i >= n_mtp_first;
|
||||
if ((split_cache || replicate_mla) && !is_mtp_tail) {
|
||||
buft_layer_count[model.buft_layer[i].buft_matrix]++;
|
||||
if (model.buft_layer[i].buft != model.buft_layer[i].buft_matrix) {
|
||||
@ -897,7 +898,8 @@ static bool llama_kv_cache_init(
|
||||
}
|
||||
|
||||
int n_mla = 0;
|
||||
const int64_t n_mtp_first_layer = n_layer - hparams.nextn_predict_layers;
|
||||
int n_kv_active_layers = 0;
|
||||
const int64_t n_mtp_first_layer = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
for (int i = 0; i < (int) n_layer; i++) {
|
||||
// For MTP-only context, skip KV allocation for non-MTP layers
|
||||
if (cparams.mtp_op_type != MTP_OP_NONE && i < (int)n_mtp_first_layer) {
|
||||
@ -907,13 +909,15 @@ static bool llama_kv_cache_init(
|
||||
}
|
||||
continue;
|
||||
}
|
||||
n_kv_active_layers++;
|
||||
const bool qnext_recurrent = llama_is_recurrent_layer(hparams, i);
|
||||
const uint32_t n_embd_v_row = llama_kv_v_row_embd(model, hparams, i);
|
||||
const uint32_t n_head_kv = hparams.n_head_kv(i);
|
||||
const uint32_t n_embd_head_k= hparams.n_embd_head_k(i);
|
||||
|
||||
const bool is_mtp_tail_layer = (model.arch == LLM_ARCH_QWEN35 ||
|
||||
model.arch == LLM_ARCH_QWEN35MOE) &&
|
||||
model.arch == LLM_ARCH_QWEN35MOE ||
|
||||
model.arch == LLM_ARCH_GLM_DSA) &&
|
||||
hparams.nextn_predict_layers > 0 && i >= (int)n_mtp_first_layer;
|
||||
//struct ggml_context * ctx = split_cache && !qnext_recurrent ? ctx_map.at(model.buft_layer[i].buft_matrix) : offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
|
||||
struct ggml_context * ctx = ((split_cache || replicate_mla) && !is_mtp_tail_layer) ? ctx_map.at(model.buft_layer[i].buft_matrix) : offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
|
||||
@ -1083,8 +1087,8 @@ static bool llama_kv_cache_init(
|
||||
cache.v_l.push_back(v);
|
||||
}
|
||||
}
|
||||
if (is_mla_attn && cparams.mla_attn && n_mla < n_layer && n_mla > 0) {
|
||||
LLAMA_LOG_ERROR("%s: unexpected situation with %d out of %d layers having MLA enabled\n", __func__, n_mla, int(n_layer));
|
||||
if (is_mla_attn && cparams.mla_attn && n_mla < n_kv_active_layers && n_mla > 0) {
|
||||
LLAMA_LOG_ERROR("%s: unexpected situation with %d out of %d active KV layers having MLA enabled\n", __func__, n_mla, n_kv_active_layers);
|
||||
LLAMA_LOG_ERROR("%s: bailing out\n", __func__);
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
@ -6820,7 +6824,8 @@ struct llama_context * llama_init_from_model(
|
||||
|
||||
if (model->arch != LLM_ARCH_GLM4_MOE && model->arch != LLM_ARCH_QWEN35 &&
|
||||
model->arch != LLM_ARCH_QWEN35MOE && model->arch != LLM_ARCH_GEMMA4 &&
|
||||
model->arch != LLM_ARCH_GEMMA4_MTP && cparams.mtp != 0) {
|
||||
model->arch != LLM_ARCH_GEMMA4_MTP && model->arch != LLM_ARCH_GLM_DSA &&
|
||||
cparams.mtp != 0) {
|
||||
cparams.mtp = 0;
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user