Fix Qwen3.6-MoE low MTP acceptance rate (#1815)

* Fix Qwen3.6-MoE low MTP acceptance rate

* Fix Gemma4 MTP
This commit is contained in:
Kawrakow 2026-05-18 07:26:17 +03:00 committed by GitHub
parent c35189d83c
commit a407b9ca3d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 20 additions and 10 deletions

View File

@ -663,7 +663,7 @@ ggml_cgraph * llm_build_context::build_gemma4_mtp() {
// not required for correct inference — the full-vocab matmul against the tied output
// weight still yields valid per-token logits.
{
logits = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb);
logits = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb, false);
cb(logits, "result_output", -1);
}
ggml_build_forward_expand(gf, logits);

View File

@ -438,9 +438,9 @@ ggml_cgraph * llm_build_context::append_pooling(struct ggml_cgraph * gf) {
for (int i = gf->n_nodes - 1; i >= 0; --i) {
inp = gf->nodes[i];
if (strcmp(inp->name, "result_norm") == 0 ||
strcmp(inp->name, "result_embd") == 0 ||
strcmp(inp->name, "output_normed") == 0) {
if (strcmp(inp->name, "result_norm") == 0 ||
strcmp(inp->name, "result_embd") == 0 ||
strcmp(inp->name, "output_normed") == 0) {
break;
}
inp = nullptr;
@ -2048,25 +2048,30 @@ ggml_tensor * llm_build_context::build_output(llama_context & lctx, ggml_context
}
ggml_tensor * llm_build_context::build_output(llama_context & lctx, ggml_context * ctx, ggml_tensor * cur,
ggml_tensor * output, ggml_tensor * output_norm, const llm_build_cb & cb) {
ggml_tensor * output, ggml_tensor * output_norm, const llm_build_cb & cb, bool add_normed_name) {
// lm_head
if (output->extra) {
auto split_output = (ggml_split_tensor_t *)output->extra;
auto split_output_norm = output_norm && output_norm->extra ? (ggml_split_tensor_t *)output_norm->extra : nullptr;
std::vector<ggml_tensor *> o;
o.reserve(split_output->n_device);
ggml_tensor * last_norm = nullptr;
for (int id = 0; id < split_output->n_device; ++id) {
auto split = split_output->splits[id];
if (!split) continue;
if (output_norm) {
auto the_norm = split_output_norm ? split_output_norm->splits[id] : output_norm;
auto cur_normed = llm_build_context::llm_build_norm(ctx, cur, lctx.model.hparams, the_norm, NULL, LLM_NORM_RMS, cb, -1);
last_norm = cur_normed;
cb(cur_normed, "result_norm", 1000*(id+1));
o.push_back(llm_build_context::llm_build_lora_mm(lctx, ctx, split, cur_normed));
} else {
o.push_back(llm_build_context::llm_build_lora_mm(lctx, ctx, split, cur));
}
cb(o.back(), "output", id);
if (add_normed_name && last_norm) {
cb(last_norm, "result_norm", -1);
}
}
GGML_ASSERT(!o.empty());
if (o.size() == 1) {
@ -2090,7 +2095,9 @@ ggml_tensor * llm_build_context::build_output(llama_context & lctx, ggml_context
}
if (output_norm) {
cur = llm_build_context::llm_build_norm(ctx, cur, lctx.model.hparams, output_norm, NULL, LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
if (add_normed_name) {
cb(cur, "result_norm", -1);
}
}
cur = llm_build_context::llm_build_lora_mm(lctx, ctx, output, cur);
}

View File

@ -448,7 +448,7 @@ llm_expert_gating_func_type gating_op,
static ggml_tensor * build_output(llama_context & lctx, ggml_context * ctx, ggml_tensor * cur, ggml_tensor * output, const llm_build_cb & cb);
static ggml_tensor * build_output(llama_context & lctx, ggml_context * ctx, ggml_tensor * cur,
ggml_tensor * output, ggml_tensor * output_norm, const llm_build_cb & cb);
ggml_tensor * output, ggml_tensor * output_norm, const llm_build_cb & cb, bool add_normed_name = true);
static ggml_tensor * do_split_norm(ggml_context * ctx, ggml_tensor * cur, ggml_tensor * the_norm, const llama_hparams & hparams,
const llm_build_cb & cb, int id, int il_cb, bool is_norm);

View File

@ -1612,7 +1612,7 @@ static void restore_recurrent_cache_tensors(int step, ggml_backend_sched_t sched
size_t ssm_bytes, size_t conv_bytes,
ggml_tensor * s_l, ggml_tensor * per_step_ssm, ggml_tensor * per_step_conv,
std::unordered_set<ggml_backend_t> & backends_to_sync) {
auto dst_backend = ggml_backend_sched_get_tensor_backend(sched, s_l);
auto dst_backend = ggml_backend_sched_get_backend(sched, ggml_backend_sched_get_backend_idx(sched, s_l->buffer));
auto dst = *s_l;
dst.ne[0] = ssm_bytes/sizeof(float);
dst.nb[1] = dst.nb[2] = dst.nb[3] = ssm_bytes + conv_bytes;
@ -4766,8 +4766,10 @@ static int llama_decode_internal(
}
else {
const bool has_mtp = llama_context_has_mtp_outputs(lctx);
const bool use_raw_mtp_embd = has_mtp && (lctx.model.arch == LLM_ARCH_QWEN35 ||
lctx.model.arch == LLM_ARCH_QWEN35MOE || lctx.model.arch == LLM_ARCH_GEMMA4 || lctx.model.arch == LLM_ARCH_GEMMA4_MTP);
const bool use_raw_mtp_embd = has_mtp && (lctx.model.arch == LLM_ARCH_QWEN35 ||
lctx.model.arch == LLM_ARCH_QWEN35MOE ||
lctx.model.arch == LLM_ARCH_GEMMA4 ||
lctx.model.arch == LLM_ARCH_GEMMA4_MTP);
if (cparams.embeddings || has_mtp) {
for (int i = gf->n_nodes - 1; i >= 0; --i) {
if (use_raw_mtp_embd && strcmp(gf->nodes[i]->name, "result_mtp_embd") == 0) {
@ -4781,6 +4783,7 @@ static int llama_decode_internal(
}
if (strcmp(gf->nodes[i]->name, "result_norm") == 0) {
embd = gf->nodes[i];
break;
}
}
}