mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Fix Gemma4 MTP compute graph (#1993)
* Fix MTP warmup for GLM models
* Fix Gemma4 MTP compute graph
This commit is contained in:
parent
0d59973e4a
commit
7321648844
@@ -630,7 +630,8 @@ ggml_cgraph * llm_build_context::build_gemma4_mtp() {
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il);
|
||||
cb(Qcur, "Qcur_normed", il);
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
|
||||
auto freq_factors = is_sliding ? nullptr : model.layers[il].rope_freqs;
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Qcur, "Qcur_rope", il);
|
||||
|
||||
@@ -664,21 +665,24 @@ ggml_cgraph * llm_build_context::build_gemma4_mtp() {
|
||||
cb(cur, "l_out", il);
|
||||
}
|
||||
|
||||
ggml_tensor * mtp_embd = llm_build_lora_mm(lctx, ctx0, model.mtp_post_proj, cur);
|
||||
cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, nullptr, LLM_NORM_RMS, cb, -1);
|
||||
cb(cur, "l_out_normed", -1);
|
||||
|
||||
auto mtp_embd = llm_build_lora_mm(lctx, ctx0, model.mtp_post_proj, cur);
|
||||
cb(mtp_embd, "result_mtp_embd", -1);
|
||||
ggml_set_output(mtp_embd);
|
||||
ggml_build_forward_expand(gf, mtp_embd);
|
||||
|
||||
ggml_tensor * logits;
|
||||
// E2B/E4B: The centroid/token-ordering tensors are kept in the GGUF for future use but
|
||||
// not required for correct inference — the full-vocab matmul against the tied output
|
||||
// weight still yields valid per-token logits.
|
||||
{
|
||||
logits = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb, false);
|
||||
cb(logits, "result_output", -1);
|
||||
}
|
||||
auto logits = llm_build_context::llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
||||
cb(logits, "result_output", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, logits);
|
||||
|
||||
return gf;
|
||||
|
||||
GGML_UNUSED(n_embd);
|
||||
GGML_UNUSED(n_vocab);
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user