Fix Gemma4-MoE graph parallel (#1604)

This commit is contained in:
Kawrakow 2026-04-09 17:31:09 +02:00 committed by GitHub
parent 557b674f63
commit 13d7178db9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -754,7 +754,6 @@ ggml_tensor * llm_build_context::llm_build_ffn(
if (!up_b && !up_s && !gate_b && !gate_s && !down_b && !down_s &&
up->extra && gate->extra && down->extra && type_gate == LLM_FFN_PAR &&
(type_op == LLM_FFN_SILU || type_op == LLM_FFN_RELU || (type_op == LLM_FFN_GELU && !act_scales))) {
//printf("%s: %s\n", __func__, ggml_op_name(input->op));
auto unary_op = type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU :
type_op == LLM_FFN_RELU ? GGML_UNARY_OP_RELU : GGML_UNARY_OP_GELU;
auto u = (ggml_split_tensor_t *)up->extra;
@ -778,7 +777,6 @@ ggml_tensor * llm_build_context::llm_build_ffn(
cur = ggml_fused_up_gate(ctx, split_u, split_g, cur, unary_op);
cb(cur, "ffn_up_gate", il_cb);
if (lctx.model.arch == LLM_ARCH_STEP35) {
//printf("%s(%d): limits = %g\n", __func__, il, lctx.model.hparams.swiglu_limits[il]);
*(float *)(cur->op_params + 1) = lctx.model.hparams.swiglu_limits[il];
}
cur = llm_build_lora_mm(lctx, ctx, split_d, cur);
@ -2160,7 +2158,6 @@ ggml_cgraph * llm_build_context::build_llama() {
Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il, nullptr,
this_n_swa);
}
//printf("%s: attn result for layer %d is %s, %s\n", __func__, il, cur->name, ggml_op_name(cur->op));
if (il == n_layer - 1 && !use_rope && inp_out_ids) {
// skip computing output for unused tokens
@ -2243,7 +2240,6 @@ ggml_cgraph * llm_build_context::build_llama() {
cb, il, gf, true);
cb(cur, "ffn_moe_out", il);
}
//printf("%s: ffn result for layer %d is %s, %s\n", __func__, il, cur->name, ggml_op_name(cur->op));
// For Granite architecture
if (hparams.f_residual_scale) {
@ -6226,11 +6222,6 @@ static ggml_cgraph * build_gemma4_graph_paralle(llm_build_context & llm, llama_c
ggml_row_size(split_vl->type, n_embd_head_v), 0);
cb(v, "v", il_cb);
//if (il == 0 || il == 5) {
// if (il == 0 && id == 0) printf("\n");
// if (id == 0) printf("--- il = %d\n", il);
// printf("id = %d, q: %ld x %ld x %ld x %ld, k: %ld x %ld x %ld x %ld v: %ld x %ld x %ld x %ld\n", id, q->ne[0], q->ne[1], q->ne[2], q->ne[3], k->ne[0], k->ne[1], k->ne[2], k->ne[3], v->ne[0], v->ne[1], v->ne[2], v->ne[3]);
//}
cur = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask_l, hparams.f_attention_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
cb(cur, "fa", il_cb);
@ -6269,6 +6260,9 @@ static ggml_cgraph * build_gemma4_graph_paralle(llm_build_context & llm, llama_c
(!ffn_up->splits[id] && !ffn_gate->splits[id] && !ffn_down->splits[id]));
if (!ffn_up->splits[id]) {
ffn_inp[id] = ffn_out[id] = nullptr;
if (is_moe) {
ffn_out_moe[id] = nullptr;
}
continue;
}
int il_cb = 1000*(il + 1) + id;
@ -6641,10 +6635,6 @@ ggml_cgraph * llm_build_context::build_gemma4() {
// layer_scalar
if (model.layers[il].out_scale) {
//if (ggml_backend_buffer_is_host(model.layers[il].out_scale->buffer)) {
// auto val = (const float *)model.layers[il].out_scale->data;
// printf("Layer %d: out_scale = %g\n", il, val[0]);
//}
cur = ggml_mul(ctx0, cur, model.layers[il].out_scale);
cb(cur, "out_scaled", il);
}
@ -7673,7 +7663,6 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
break;
}
}
//printf("Using n_max_head = %d -> kv_f32_size = %zu\n", n_max_head, kv_f32_size/(n_head/n_max_head));
}
GGML_ASSERT(n_head % n_max_head == 0);
@ -9039,7 +9028,6 @@ ggml_cgraph * llm_build_context::build_chatglm() {
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
//printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor);
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,