From a407b9ca3dddc0b83eca275d5d2ff3fe7cfa8c4f Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Mon, 18 May 2026 07:26:17 +0300 Subject: [PATCH] Fix Qwen3.6-MoE low MTP acceptance rate (#1815) * Fix Qwen3.6-MoE low MTP acceptance rate * Fix Gemma4 MTP --- src/graphs/build_gemma4.cpp | 2 +- src/llama-build-context.cpp | 17 ++++++++++++----- src/llama-build-context.h | 2 +- src/llama.cpp | 9 ++++++--- 4 files changed, 20 insertions(+), 10 deletions(-) diff --git a/src/graphs/build_gemma4.cpp b/src/graphs/build_gemma4.cpp index 249136f9..b4ff6e31 100644 --- a/src/graphs/build_gemma4.cpp +++ b/src/graphs/build_gemma4.cpp @@ -663,7 +663,7 @@ ggml_cgraph * llm_build_context::build_gemma4_mtp() { // not required for correct inference — the full-vocab matmul against the tied output // weight still yields valid per-token logits. { - logits = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb); + logits = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb, false); cb(logits, "result_output", -1); } ggml_build_forward_expand(gf, logits); diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index fca8c4b2..f4d0f61e 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -438,9 +438,9 @@ ggml_cgraph * llm_build_context::append_pooling(struct ggml_cgraph * gf) { for (int i = gf->n_nodes - 1; i >= 0; --i) { inp = gf->nodes[i]; - if (strcmp(inp->name, "result_norm") == 0 || - strcmp(inp->name, "result_embd") == 0 || - strcmp(inp->name, "output_normed") == 0) { + if (strcmp(inp->name, "result_norm") == 0 || + strcmp(inp->name, "result_embd") == 0 || + strcmp(inp->name, "output_normed") == 0) { break; } inp = nullptr; @@ -2048,25 +2048,30 @@ ggml_tensor * llm_build_context::build_output(llama_context & lctx, ggml_context } ggml_tensor * llm_build_context::build_output(llama_context & lctx, ggml_context * ctx, ggml_tensor * cur, - ggml_tensor * output, ggml_tensor * output_norm, const llm_build_cb & cb) { + ggml_tensor * output, ggml_tensor * output_norm, const llm_build_cb & cb, bool add_normed_name) { // lm_head if (output->extra) { auto split_output = (ggml_split_tensor_t *)output->extra; auto split_output_norm = output_norm && output_norm->extra ? (ggml_split_tensor_t *)output_norm->extra : nullptr; std::vector o; o.reserve(split_output->n_device); + ggml_tensor * last_norm = nullptr; for (int id = 0; id < split_output->n_device; ++id) { auto split = split_output->splits[id]; if (!split) continue; if (output_norm) { auto the_norm = split_output_norm ? split_output_norm->splits[id] : output_norm; auto cur_normed = llm_build_context::llm_build_norm(ctx, cur, lctx.model.hparams, the_norm, NULL, LLM_NORM_RMS, cb, -1); + last_norm = cur_normed; cb(cur_normed, "result_norm", 1000*(id+1)); o.push_back(llm_build_context::llm_build_lora_mm(lctx, ctx, split, cur_normed)); } else { o.push_back(llm_build_context::llm_build_lora_mm(lctx, ctx, split, cur)); } cb(o.back(), "output", id); + if (add_normed_name && last_norm) { + cb(last_norm, "result_norm", -1); + } } GGML_ASSERT(!o.empty()); if (o.size() == 1) { @@ -2090,7 +2095,9 @@ ggml_tensor * llm_build_context::build_output(llama_context & lctx, ggml_context } if (output_norm) { cur = llm_build_context::llm_build_norm(ctx, cur, lctx.model.hparams, output_norm, NULL, LLM_NORM_RMS, cb, -1); - cb(cur, "result_norm", -1); + if (add_normed_name) { + cb(cur, "result_norm", -1); + } } cur = llm_build_context::llm_build_lora_mm(lctx, ctx, output, cur); } diff --git a/src/llama-build-context.h b/src/llama-build-context.h index 361f9710..3ef81d01 100644 --- a/src/llama-build-context.h +++ b/src/llama-build-context.h @@ -448,7 +448,7 @@ llm_expert_gating_func_type gating_op, static ggml_tensor * build_output(llama_context & lctx, ggml_context * ctx, ggml_tensor * cur, ggml_tensor * output, const llm_build_cb & cb); static ggml_tensor * build_output(llama_context & lctx, ggml_context * ctx, ggml_tensor * cur, - ggml_tensor * output, ggml_tensor * output_norm, const llm_build_cb & cb); + ggml_tensor * output, ggml_tensor * output_norm, const llm_build_cb & cb, bool add_normed_name = true); static ggml_tensor * do_split_norm(ggml_context * ctx, ggml_tensor * cur, ggml_tensor * the_norm, const llama_hparams & hparams, const llm_build_cb & cb, int id, int il_cb, bool is_norm); diff --git a/src/llama.cpp b/src/llama.cpp index 2850fa52..879028d6 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1612,7 +1612,7 @@ static void restore_recurrent_cache_tensors(int step, ggml_backend_sched_t sched size_t ssm_bytes, size_t conv_bytes, ggml_tensor * s_l, ggml_tensor * per_step_ssm, ggml_tensor * per_step_conv, std::unordered_set & backends_to_sync) { - auto dst_backend = ggml_backend_sched_get_tensor_backend(sched, s_l); + auto dst_backend = ggml_backend_sched_get_backend(sched, ggml_backend_sched_get_backend_idx(sched, s_l->buffer)); auto dst = *s_l; dst.ne[0] = ssm_bytes/sizeof(float); dst.nb[1] = dst.nb[2] = dst.nb[3] = ssm_bytes + conv_bytes; @@ -4766,8 +4766,10 @@ static int llama_decode_internal( } else { const bool has_mtp = llama_context_has_mtp_outputs(lctx); - const bool use_raw_mtp_embd = has_mtp && (lctx.model.arch == LLM_ARCH_QWEN35 || - lctx.model.arch == LLM_ARCH_QWEN35MOE || lctx.model.arch == LLM_ARCH_GEMMA4 || lctx.model.arch == LLM_ARCH_GEMMA4_MTP); + const bool use_raw_mtp_embd = has_mtp && (lctx.model.arch == LLM_ARCH_QWEN35 || + lctx.model.arch == LLM_ARCH_QWEN35MOE || + lctx.model.arch == LLM_ARCH_GEMMA4 || + lctx.model.arch == LLM_ARCH_GEMMA4_MTP); if (cparams.embeddings || has_mtp) { for (int i = gf->n_nodes - 1; i >= 0; --i) { if (use_raw_mtp_embd && strcmp(gf->nodes[i]->name, "result_mtp_embd") == 0) { @@ -4781,6 +4783,7 @@ static int llama_decode_internal( } if (strcmp(gf->nodes[i]->name, "result_norm") == 0) { embd = gf->nodes[i]; + break; } } }