Fix Gemma4 MTP compute graph

This commit is contained in:
Kawrakow 2026-06-18 15:51:22 +00:00
parent 2c1dc8781b
commit 67b0b22760

View File

@ -630,7 +630,8 @@ ggml_cgraph * llm_build_context::build_gemma4_mtp() {
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il); Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur_normed", il); cb(Qcur, "Qcur_normed", il);
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, auto freq_factors = is_sliding ? nullptr : model.layers[il].rope_freqs;
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow); ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur_rope", il); cb(Qcur, "Qcur_rope", il);
@ -664,21 +665,24 @@ ggml_cgraph * llm_build_context::build_gemma4_mtp() {
cb(cur, "l_out", il); cb(cur, "l_out", il);
} }
ggml_tensor * mtp_embd = llm_build_lora_mm(lctx, ctx0, model.mtp_post_proj, cur); cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, nullptr, LLM_NORM_RMS, cb, -1);
cb(cur, "l_out_normed", -1);
auto mtp_embd = llm_build_lora_mm(lctx, ctx0, model.mtp_post_proj, cur);
cb(mtp_embd, "result_mtp_embd", -1); cb(mtp_embd, "result_mtp_embd", -1);
ggml_set_output(mtp_embd); ggml_set_output(mtp_embd);
ggml_build_forward_expand(gf, mtp_embd); ggml_build_forward_expand(gf, mtp_embd);
ggml_tensor * logits;
// E2B/E4B: The centroid/token-ordering tensors are kept in the GGUF for future use but // E2B/E4B: The centroid/token-ordering tensors are kept in the GGUF for future use but
// not required for correct inference — the full-vocab matmul against the tied output // not required for correct inference — the full-vocab matmul against the tied output
// weight still yields valid per-token logits. // weight still yields valid per-token logits.
{ auto logits = llm_build_context::llm_build_lora_mm(lctx, ctx0, model.output, cur);
logits = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb, false); cb(logits, "result_output", -1);
cb(logits, "result_output", -1);
}
ggml_build_forward_expand(gf, logits); ggml_build_forward_expand(gf, logits);
return gf;
GGML_UNUSED(n_embd); GGML_UNUSED(n_embd);
GGML_UNUSED(n_vocab); GGML_UNUSED(n_vocab);