diff --git a/src/graphs/build_deepseek2.cpp b/src/graphs/build_deepseek2.cpp index 0c4a2f70..1d78a826 100644 --- a/src/graphs/build_deepseek2.cpp +++ b/src/graphs/build_deepseek2.cpp @@ -32,6 +32,9 @@ ggml_tensor * llm_build_context::build_deepseek2_tp_attention( const uint32_t n_embd_head_v = hparams.n_embd_head_v(il); auto cache_repl = (const ggml_split_tensor_t *)kv_self.k_l[il]->extra; + if (!cache_repl) { + LLAMA_LOG_ERROR("%s: no cache split for layer %d?\n", __func__, il); + } GGML_ASSERT(cache_repl); GGML_ASSERT(cache_repl->n_device == n_device); @@ -799,6 +802,32 @@ ggml_cgraph * llm_build_context::build_deepseek2() { ggml_rope_cache(ctx0, inp_pos, nullptr, n_rot, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow) : nullptr; + if (cparams.mtp_op_type != MTP_OP_NONE) { + if (model.arch != LLM_ARCH_GLM_DSA || !model.mtp || hparams.nextn_predict_layers == 0) { + GGML_ABORT("MTP tail is only wired for GLM_DSA models with NextN layers enabled"); + } + + ggml_tensor * hidden_states_from_main_model; + + if (cparams.mtp_op_type == MTP_OP_WARMUP || cparams.mtp_op_type == MTP_OP_UPDATE_ACCEPTED) { + hidden_states_from_main_model = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + } else { + hidden_states_from_main_model = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd); + } + ggml_set_name(hidden_states_from_main_model, "inp_mtp_states"); + ggml_set_input(hidden_states_from_main_model); + + lctx.inp_mtp_states = hidden_states_from_main_model; + + const int il_mtp = hparams.n_layer - 1; + const auto & mtp_layer = model.layers[il_mtp]; + + cur = build_deepseek2_mtp(mtp_layer, hidden_states_from_main_model, gf, inp_pos, rope_cache); + + ggml_build_forward_expand(gf, cur); + return gf; + } + int n_active_layers = hparams.n_layer - hparams.nextn_predict_layers; for (int il = 0; il < n_active_layers; ++il) { struct ggml_tensor * inpSA = inpL; @@ -815,7 +844,7 @@ ggml_cgraph * llm_build_context::build_deepseek2() { use_f32_attn_precision, is_lite, pp_opt); } - if (il == n_active_layers - 1) { + if (il == n_active_layers - 1 && !lctx.cparams.mtp) { // skip computing output for unused tokens struct ggml_tensor * inp_out_ids = build_inp_out_ids(); n_tokens = n_outputs; @@ -914,3 +943,114 @@ ggml_cgraph * llm_build_context::build_deepseek2() { return gf; } + +struct ggml_tensor * llm_build_context::build_deepseek2_mtp( + const llama_layer & mtp_layer, + struct ggml_tensor * prev_embeddings, + struct ggml_cgraph * gf, + struct ggml_tensor * inp_pos, + [[maybe_unused]] struct ggml_tensor * rope_cache) { +#ifdef GGML_USE_VULKAN + constexpr bool use_f32_attn_precision = true; +#else + constexpr bool use_f32_attn_precision = false; +#endif + + const int il = hparams.n_layer - 1; + + const uint32_t n_embd_head_k_mtp = hparams.n_embd_head_k(il); + + const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale)); + const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(n_embd_head_k_mtp)); + const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); + + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + struct ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr; + + // Token embedding + ggml_tensor * mtp_embd_weights = mtp_layer.nextn.embed_tokens; + if (mtp_embd_weights == nullptr) { + mtp_embd_weights = model.tok_embd; + } + ggml_tensor * token_emb = build_inp_embd_mtp(mtp_embd_weights); + + // Normalize and project + ggml_tensor * token_emb_norm = llm_build_norm(ctx0, token_emb, hparams, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, cb, il); + ggml_tensor * hidden_state_norm = llm_build_norm(ctx0, prev_embeddings, hparams, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, cb, il); + + if (mtp_layer.nextn.eh_proj == nullptr) { + GGML_ABORT("GLM_DSA MTP requires nextn.eh_proj"); + } + + ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); + cb(combined, "mtp_concat", il); + ggml_tensor * cur = llm_build_lora_mm(lctx, ctx0, mtp_layer.nextn.eh_proj, combined); + + struct ggml_tensor * inpSA = cur; + + cur = build_deepseek2_layer_attention(gf, il, cur, KQ_mask, inp_pos, nullptr, + kq_scale, attn_factor_scaled, + use_f32_attn_precision, false, false); + + // Residual + FFN + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "mtp_ffn_inp", il); + + if (inp_out_ids) { + ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids); + } + + cur = llm_build_norm(ctx0, ffn_inp, hparams, mtp_layer.ffn_norm, NULL, LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + // MoE FFN (MTP layer is always in the MoE range, not dense) + { + ggml_tensor * moe_out = + llm_build_moe_ffn(ctx0, lctx, cur, + mtp_layer.ffn_gate_inp, + mtp_layer.ffn_up_exps, + mtp_layer.ffn_gate_exps, + mtp_layer.ffn_down_exps, + mtp_layer.ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (enum llm_expert_gating_func_type) hparams.expert_gating_func, + cb, il, gf, false, mtp_layer.ffn_up_gate_exps); + cb(moe_out, "ffn_moe_out", il); + + // Shared Expert FFN + ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, nullptr, cur, + mtp_layer.ffn_up_shexp, NULL, NULL, + mtp_layer.ffn_gate_shexp, NULL, NULL, + mtp_layer.ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "mtp_ffn_out_resid", il); + + // Output head + if (mtp_layer.nextn.shared_head_norm == nullptr) { + GGML_ABORT("GLM_DSA MTP requires nextn.shared_head_norm"); + } + + cur = llm_build_norm(ctx0, cur, hparams, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, cb, il); + cb(cur, "result_norm", -1); + + // If nextn.shared_head_head is missing, use model.output (Main LM Head) + ggml_tensor * mtp_head_weights = mtp_layer.nextn.shared_head_head; + if (mtp_head_weights == nullptr) { + mtp_head_weights = model.output; + } + cur = llm_build_lora_mm(lctx, ctx0, mtp_head_weights, cur); + cb(cur, "result_output", -1); + + return cur; +} diff --git a/src/llama-build-context.h b/src/llama-build-context.h index ab6129ca..73490c3a 100644 --- a/src/llama-build-context.h +++ b/src/llama-build-context.h @@ -486,6 +486,14 @@ llm_expert_gating_func_type gating_op, struct ggml_tensor * rope_cache ); + struct ggml_tensor * build_deepseek2_mtp( + const struct llama_layer & mtp_layer, + struct ggml_tensor * prev_embeddings, + struct ggml_cgraph * gf, + struct ggml_tensor * inp_pos, + struct ggml_tensor * rope_cache + ); + struct ggml_tensor * build_qwen35_mtp( const struct llama_layer & mtp_layer, struct ggml_tensor * prev_embeddings, diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 21c7a415..f3be1c11 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -1379,8 +1379,12 @@ void llm_load_hparams( // NextN/MTP parameters ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + if (model.mtp) { + hparams.n_layer_kv_from_start = hparams.n_layer; + } + else { + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + } switch (hparams.n_layer) { case 79: model.type = MODEL_744B_A40B; break; diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp index 5fbfd010..b3f1ff06 100644 --- a/src/llama-load-tensors.cpp +++ b/src/llama-load-tensors.cpp @@ -2616,8 +2616,11 @@ bool create_tensors_helper::create_glm_dsa_tensors(const LLM_TN & tn) { static_cast(i) >= n_layer - hparams.nextn_predict_layers; int flags = 0; - if (is_mtp_layer) { - flags |= llama_model_loader::TENSOR_SKIP | llama_model_loader::TENSOR_NOT_REQUIRED; + // Skip loading MTP layers if the feature is disabled + if (!model.mtp) { + if (is_mtp_layer) { + flags |= llama_model_loader::TENSOR_SKIP | llama_model_loader::TENSOR_NOT_REQUIRED; + } } ggml_context * ctx_layer = ctx_for_layer(i); ggml_context * ctx_split = ctx_for_layer_split(i); @@ -2701,14 +2704,14 @@ bool create_tensors_helper::create_glm_dsa_tensors(const LLM_TN & tn) { } if (is_mtp_layer) { - layer.nextn.eh_proj = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); - layer.nextn.enorm = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); - layer.nextn.hnorm = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + layer.nextn.eh_proj = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); - // Optional tensors - layer.nextn.embed_tokens = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | llama_model_loader::TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_head = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | llama_model_loader::TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_norm = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | llama_model_loader::TENSOR_NOT_REQUIRED); + // Optional tensors + layer.nextn.embed_tokens = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | llama_model_loader::TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | llama_model_loader::TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags); } } return use_mmap_buffer; diff --git a/src/llama.cpp b/src/llama.cpp index 67e05a47..a14259cc 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -770,7 +770,7 @@ static bool llama_kv_cache_init( const struct llama_hparams & hparams = model.hparams; const int64_t n_layer = model.mtp ? hparams.n_layer - : hparams.n_layer - hparams.nextn_predict_layers; + : hparams.n_layer - hparams.nextn_predict_layers; cache.has_shift = false; @@ -818,11 +818,12 @@ static bool llama_kv_cache_init( // count used buffer types std::map buft_layer_count; if (offload) { - const bool qwen_mtp = (model.arch == LLM_ARCH_QWEN35 || - model.arch == LLM_ARCH_QWEN35MOE) && hparams.nextn_predict_layers > 0; - const int64_t n_mtp_first = n_layer - hparams.nextn_predict_layers; + const bool is_mtp = (model.arch == LLM_ARCH_GLM_DSA || + model.arch == LLM_ARCH_QWEN35 || + model.arch == LLM_ARCH_QWEN35MOE) && hparams.nextn_predict_layers > 0; + const int64_t n_mtp_first = hparams.n_layer - hparams.nextn_predict_layers; for (int64_t i = 0; i < n_layer; ++i) { - const bool is_mtp_tail = qwen_mtp && i >= n_mtp_first; + const bool is_mtp_tail = is_mtp && i >= n_mtp_first; if ((split_cache || replicate_mla) && !is_mtp_tail) { buft_layer_count[model.buft_layer[i].buft_matrix]++; if (model.buft_layer[i].buft != model.buft_layer[i].buft_matrix) { @@ -897,7 +898,8 @@ static bool llama_kv_cache_init( } int n_mla = 0; - const int64_t n_mtp_first_layer = n_layer - hparams.nextn_predict_layers; + int n_kv_active_layers = 0; + const int64_t n_mtp_first_layer = hparams.n_layer - hparams.nextn_predict_layers; for (int i = 0; i < (int) n_layer; i++) { // For MTP-only context, skip KV allocation for non-MTP layers if (cparams.mtp_op_type != MTP_OP_NONE && i < (int)n_mtp_first_layer) { @@ -907,13 +909,15 @@ static bool llama_kv_cache_init( } continue; } + n_kv_active_layers++; const bool qnext_recurrent = llama_is_recurrent_layer(hparams, i); const uint32_t n_embd_v_row = llama_kv_v_row_embd(model, hparams, i); const uint32_t n_head_kv = hparams.n_head_kv(i); const uint32_t n_embd_head_k= hparams.n_embd_head_k(i); const bool is_mtp_tail_layer = (model.arch == LLM_ARCH_QWEN35 || - model.arch == LLM_ARCH_QWEN35MOE) && + model.arch == LLM_ARCH_QWEN35MOE || + model.arch == LLM_ARCH_GLM_DSA) && hparams.nextn_predict_layers > 0 && i >= (int)n_mtp_first_layer; //struct ggml_context * ctx = split_cache && !qnext_recurrent ? ctx_map.at(model.buft_layer[i].buft_matrix) : offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); struct ggml_context * ctx = ((split_cache || replicate_mla) && !is_mtp_tail_layer) ? ctx_map.at(model.buft_layer[i].buft_matrix) : offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); @@ -1083,8 +1087,8 @@ static bool llama_kv_cache_init( cache.v_l.push_back(v); } } - if (is_mla_attn && cparams.mla_attn && n_mla < n_layer && n_mla > 0) { - LLAMA_LOG_ERROR("%s: unexpected situation with %d out of %d layers having MLA enabled\n", __func__, n_mla, int(n_layer)); + if (is_mla_attn && cparams.mla_attn && n_mla < n_kv_active_layers && n_mla > 0) { + LLAMA_LOG_ERROR("%s: unexpected situation with %d out of %d active KV layers having MLA enabled\n", __func__, n_mla, n_kv_active_layers); LLAMA_LOG_ERROR("%s: bailing out\n", __func__); GGML_ABORT("fatal error"); } @@ -6820,7 +6824,8 @@ struct llama_context * llama_init_from_model( if (model->arch != LLM_ARCH_GLM4_MOE && model->arch != LLM_ARCH_QWEN35 && model->arch != LLM_ARCH_QWEN35MOE && model->arch != LLM_ARCH_GEMMA4 && - model->arch != LLM_ARCH_GEMMA4_MTP && cparams.mtp != 0) { + model->arch != LLM_ARCH_GEMMA4_MTP && model->arch != LLM_ARCH_GLM_DSA && + cparams.mtp != 0) { cparams.mtp = 0; }