From 73216488446e3127a4e2db5278c5d6d39340712d Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 19 Jun 2026 09:00:44 +0200 Subject: [PATCH] Fix Gemma4 MTP compute graph (#1993) * Fix MTP warmup for GLM models * Fix Gemma4 MTP compute graph --- src/graphs/build_gemma4.cpp | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/graphs/build_gemma4.cpp b/src/graphs/build_gemma4.cpp index 7e4c655a..0aee1f0d 100644 --- a/src/graphs/build_gemma4.cpp +++ b/src/graphs/build_gemma4.cpp @@ -630,7 +630,8 @@ ggml_cgraph * llm_build_context::build_gemma4_mtp() { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il); cb(Qcur, "Qcur_normed", il); - Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + auto freq_factors = is_sliding ? nullptr : model.layers[il].rope_freqs; + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur_rope", il); @@ -664,21 +665,24 @@ ggml_cgraph * llm_build_context::build_gemma4_mtp() { cb(cur, "l_out", il); } - ggml_tensor * mtp_embd = llm_build_lora_mm(lctx, ctx0, model.mtp_post_proj, cur); + cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, nullptr, LLM_NORM_RMS, cb, -1); + cb(cur, "l_out_normed", -1); + + auto mtp_embd = llm_build_lora_mm(lctx, ctx0, model.mtp_post_proj, cur); cb(mtp_embd, "result_mtp_embd", -1); ggml_set_output(mtp_embd); ggml_build_forward_expand(gf, mtp_embd); - ggml_tensor * logits; // E2B/E4B: The centroid/token-ordering tensors are kept in the GGUF for future use but // not required for correct inference — the full-vocab matmul against the tied output // weight still yields valid per-token logits. - { - logits = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb, false); - cb(logits, "result_output", -1); - } + auto logits = llm_build_context::llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(logits, "result_output", -1); + ggml_build_forward_expand(gf, logits); + return gf; + GGML_UNUSED(n_embd); GGML_UNUSED(n_vocab);