mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Fix Qwen3.6-MoE low MTP acceptance rate (#1815)
* Fix Qwen3.6-MoE low MTP acceptance rate * Fix Gemma4 MTP
This commit is contained in:
parent
c35189d83c
commit
a407b9ca3d
@ -663,7 +663,7 @@ ggml_cgraph * llm_build_context::build_gemma4_mtp() {
|
||||
// not required for correct inference — the full-vocab matmul against the tied output
|
||||
// weight still yields valid per-token logits.
|
||||
{
|
||||
logits = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb);
|
||||
logits = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb, false);
|
||||
cb(logits, "result_output", -1);
|
||||
}
|
||||
ggml_build_forward_expand(gf, logits);
|
||||
|
||||
@ -438,9 +438,9 @@ ggml_cgraph * llm_build_context::append_pooling(struct ggml_cgraph * gf) {
|
||||
for (int i = gf->n_nodes - 1; i >= 0; --i) {
|
||||
inp = gf->nodes[i];
|
||||
|
||||
if (strcmp(inp->name, "result_norm") == 0 ||
|
||||
strcmp(inp->name, "result_embd") == 0 ||
|
||||
strcmp(inp->name, "output_normed") == 0) {
|
||||
if (strcmp(inp->name, "result_norm") == 0 ||
|
||||
strcmp(inp->name, "result_embd") == 0 ||
|
||||
strcmp(inp->name, "output_normed") == 0) {
|
||||
break;
|
||||
}
|
||||
inp = nullptr;
|
||||
@ -2048,25 +2048,30 @@ ggml_tensor * llm_build_context::build_output(llama_context & lctx, ggml_context
|
||||
}
|
||||
|
||||
ggml_tensor * llm_build_context::build_output(llama_context & lctx, ggml_context * ctx, ggml_tensor * cur,
|
||||
ggml_tensor * output, ggml_tensor * output_norm, const llm_build_cb & cb) {
|
||||
ggml_tensor * output, ggml_tensor * output_norm, const llm_build_cb & cb, bool add_normed_name) {
|
||||
// lm_head
|
||||
if (output->extra) {
|
||||
auto split_output = (ggml_split_tensor_t *)output->extra;
|
||||
auto split_output_norm = output_norm && output_norm->extra ? (ggml_split_tensor_t *)output_norm->extra : nullptr;
|
||||
std::vector<ggml_tensor *> o;
|
||||
o.reserve(split_output->n_device);
|
||||
ggml_tensor * last_norm = nullptr;
|
||||
for (int id = 0; id < split_output->n_device; ++id) {
|
||||
auto split = split_output->splits[id];
|
||||
if (!split) continue;
|
||||
if (output_norm) {
|
||||
auto the_norm = split_output_norm ? split_output_norm->splits[id] : output_norm;
|
||||
auto cur_normed = llm_build_context::llm_build_norm(ctx, cur, lctx.model.hparams, the_norm, NULL, LLM_NORM_RMS, cb, -1);
|
||||
last_norm = cur_normed;
|
||||
cb(cur_normed, "result_norm", 1000*(id+1));
|
||||
o.push_back(llm_build_context::llm_build_lora_mm(lctx, ctx, split, cur_normed));
|
||||
} else {
|
||||
o.push_back(llm_build_context::llm_build_lora_mm(lctx, ctx, split, cur));
|
||||
}
|
||||
cb(o.back(), "output", id);
|
||||
if (add_normed_name && last_norm) {
|
||||
cb(last_norm, "result_norm", -1);
|
||||
}
|
||||
}
|
||||
GGML_ASSERT(!o.empty());
|
||||
if (o.size() == 1) {
|
||||
@ -2090,7 +2095,9 @@ ggml_tensor * llm_build_context::build_output(llama_context & lctx, ggml_context
|
||||
}
|
||||
if (output_norm) {
|
||||
cur = llm_build_context::llm_build_norm(ctx, cur, lctx.model.hparams, output_norm, NULL, LLM_NORM_RMS, cb, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
if (add_normed_name) {
|
||||
cb(cur, "result_norm", -1);
|
||||
}
|
||||
}
|
||||
cur = llm_build_context::llm_build_lora_mm(lctx, ctx, output, cur);
|
||||
}
|
||||
|
||||
@ -448,7 +448,7 @@ llm_expert_gating_func_type gating_op,
|
||||
static ggml_tensor * build_output(llama_context & lctx, ggml_context * ctx, ggml_tensor * cur, ggml_tensor * output, const llm_build_cb & cb);
|
||||
|
||||
static ggml_tensor * build_output(llama_context & lctx, ggml_context * ctx, ggml_tensor * cur,
|
||||
ggml_tensor * output, ggml_tensor * output_norm, const llm_build_cb & cb);
|
||||
ggml_tensor * output, ggml_tensor * output_norm, const llm_build_cb & cb, bool add_normed_name = true);
|
||||
|
||||
static ggml_tensor * do_split_norm(ggml_context * ctx, ggml_tensor * cur, ggml_tensor * the_norm, const llama_hparams & hparams,
|
||||
const llm_build_cb & cb, int id, int il_cb, bool is_norm);
|
||||
|
||||
@ -1612,7 +1612,7 @@ static void restore_recurrent_cache_tensors(int step, ggml_backend_sched_t sched
|
||||
size_t ssm_bytes, size_t conv_bytes,
|
||||
ggml_tensor * s_l, ggml_tensor * per_step_ssm, ggml_tensor * per_step_conv,
|
||||
std::unordered_set<ggml_backend_t> & backends_to_sync) {
|
||||
auto dst_backend = ggml_backend_sched_get_tensor_backend(sched, s_l);
|
||||
auto dst_backend = ggml_backend_sched_get_backend(sched, ggml_backend_sched_get_backend_idx(sched, s_l->buffer));
|
||||
auto dst = *s_l;
|
||||
dst.ne[0] = ssm_bytes/sizeof(float);
|
||||
dst.nb[1] = dst.nb[2] = dst.nb[3] = ssm_bytes + conv_bytes;
|
||||
@ -4766,8 +4766,10 @@ static int llama_decode_internal(
|
||||
}
|
||||
else {
|
||||
const bool has_mtp = llama_context_has_mtp_outputs(lctx);
|
||||
const bool use_raw_mtp_embd = has_mtp && (lctx.model.arch == LLM_ARCH_QWEN35 ||
|
||||
lctx.model.arch == LLM_ARCH_QWEN35MOE || lctx.model.arch == LLM_ARCH_GEMMA4 || lctx.model.arch == LLM_ARCH_GEMMA4_MTP);
|
||||
const bool use_raw_mtp_embd = has_mtp && (lctx.model.arch == LLM_ARCH_QWEN35 ||
|
||||
lctx.model.arch == LLM_ARCH_QWEN35MOE ||
|
||||
lctx.model.arch == LLM_ARCH_GEMMA4 ||
|
||||
lctx.model.arch == LLM_ARCH_GEMMA4_MTP);
|
||||
if (cparams.embeddings || has_mtp) {
|
||||
for (int i = gf->n_nodes - 1; i >= 0; --i) {
|
||||
if (use_raw_mtp_embd && strcmp(gf->nodes[i]->name, "result_mtp_embd") == 0) {
|
||||
@ -4781,6 +4783,7 @@ static int llama_decode_internal(
|
||||
}
|
||||
if (strcmp(gf->nodes[i]->name, "result_norm") == 0) {
|
||||
embd = gf->nodes[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user