From 13d7178db98217dd7e38d78bb206091ac6dab6ce Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 9 Apr 2026 17:31:09 +0200 Subject: [PATCH] Fix Gemma4-MoE graph parallel (#1604) --- src/llama-build-context.cpp | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 32eba1ae..69e09993 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -754,7 +754,6 @@ ggml_tensor * llm_build_context::llm_build_ffn( if (!up_b && !up_s && !gate_b && !gate_s && !down_b && !down_s && up->extra && gate->extra && down->extra && type_gate == LLM_FFN_PAR && (type_op == LLM_FFN_SILU || type_op == LLM_FFN_RELU || (type_op == LLM_FFN_GELU && !act_scales))) { - //printf("%s: %s\n", __func__, ggml_op_name(input->op)); auto unary_op = type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : type_op == LLM_FFN_RELU ? GGML_UNARY_OP_RELU : GGML_UNARY_OP_GELU; auto u = (ggml_split_tensor_t *)up->extra; @@ -778,7 +777,6 @@ ggml_tensor * llm_build_context::llm_build_ffn( cur = ggml_fused_up_gate(ctx, split_u, split_g, cur, unary_op); cb(cur, "ffn_up_gate", il_cb); if (lctx.model.arch == LLM_ARCH_STEP35) { - //printf("%s(%d): limits = %g\n", __func__, il, lctx.model.hparams.swiglu_limits[il]); *(float *)(cur->op_params + 1) = lctx.model.hparams.swiglu_limits[il]; } cur = llm_build_lora_mm(lctx, ctx, split_d, cur); @@ -2160,7 +2158,6 @@ ggml_cgraph * llm_build_context::build_llama() { Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il, nullptr, this_n_swa); } - //printf("%s: attn result for layer %d is %s, %s\n", __func__, il, cur->name, ggml_op_name(cur->op)); if (il == n_layer - 1 && !use_rope && inp_out_ids) { // skip computing output for unused tokens @@ -2243,7 +2240,6 @@ ggml_cgraph * llm_build_context::build_llama() { cb, il, gf, true); cb(cur, "ffn_moe_out", il); } - //printf("%s: ffn result for layer %d is %s, %s\n", __func__, il, cur->name, ggml_op_name(cur->op)); // For Granite architecture if (hparams.f_residual_scale) { @@ -6226,11 +6222,6 @@ static ggml_cgraph * build_gemma4_graph_paralle(llm_build_context & llm, llama_c ggml_row_size(split_vl->type, n_embd_head_v), 0); cb(v, "v", il_cb); - //if (il == 0 || il == 5) { - // if (il == 0 && id == 0) printf("\n"); - // if (id == 0) printf("--- il = %d\n", il); - // printf("id = %d, q: %ld x %ld x %ld x %ld, k: %ld x %ld x %ld x %ld v: %ld x %ld x %ld x %ld\n", id, q->ne[0], q->ne[1], q->ne[2], q->ne[3], k->ne[0], k->ne[1], k->ne[2], k->ne[3], v->ne[0], v->ne[1], v->ne[2], v->ne[3]); - //} cur = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask_l, hparams.f_attention_scale, hparams.f_max_alibi_bias, hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f); cb(cur, "fa", il_cb); @@ -6269,6 +6260,9 @@ static ggml_cgraph * build_gemma4_graph_paralle(llm_build_context & llm, llama_c (!ffn_up->splits[id] && !ffn_gate->splits[id] && !ffn_down->splits[id])); if (!ffn_up->splits[id]) { ffn_inp[id] = ffn_out[id] = nullptr; + if (is_moe) { + ffn_out_moe[id] = nullptr; + } continue; } int il_cb = 1000*(il + 1) + id; @@ -6641,10 +6635,6 @@ ggml_cgraph * llm_build_context::build_gemma4() { // layer_scalar if (model.layers[il].out_scale) { - //if (ggml_backend_buffer_is_host(model.layers[il].out_scale->buffer)) { - // auto val = (const float *)model.layers[il].out_scale->data; - // printf("Layer %d: out_scale = %g\n", il, val[0]); - //} cur = ggml_mul(ctx0, cur, model.layers[il].out_scale); cb(cur, "out_scaled", il); } @@ -7673,7 +7663,6 @@ ggml_cgraph * llm_build_context::build_deepseek2() { break; } } - //printf("Using n_max_head = %d -> kv_f32_size = %zu\n", n_max_head, kv_f32_size/(n_head/n_max_head)); } GGML_ASSERT(n_head % n_max_head == 0); @@ -9039,7 +9028,6 @@ ggml_cgraph * llm_build_context::build_chatglm() { cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor); Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,