Graph parallel for Gemma4 MoE (#1600)

* Use build_std_attention for Gemma4 when possible

It is possible for the 26b MoE and 31b dense models.
It is not possible for the E4B/E2B vaiants because they
don't have KV cache in each layer.

* Standardize Gemma4 dense ffn

* WIP: Gemma4 split mode graph

Runs but produces NaNs

* WIP: Gemma4 split mode graph

Runs but very high PPL. At least it is no longer NaN.

* WIP

* This works!

* Put attn_norm, attn_post_norm, ffn_norm, ffn_post_norm on all GPUs

* Fix crash when saving/loading KV cache

* WIP: split mode graph for Gemma4-MoE - crashes

* Split mode graph for Gemma4-MoE - this works

* Disable SWA optimization

Something goes wrong there

* Consolidate MoE and dense graph parallel
This commit is contained in:
Kawrakow 2026-04-09 14:07:29 +02:00 committed by GitHub
parent 9db5d9907e
commit 847e191936
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 107 additions and 15 deletions

View File

@ -40,6 +40,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
int ntokens = std::max(FATTN_KQ_STRIDE, int(Q->ne[1]));
int nton = FATTN_KQ_STRIDE*((ntokens + n_swa + FATTN_KQ_STRIDE - 1)/FATTN_KQ_STRIDE);
int first = K->ne[1] - nton;
local_dst = *dst;
local_dst.op_params[4] = 0;
if (first > 0) {
local_dst = *dst;
Kl = *K; Kl.ne[1] = nton; Kl.data = (char *)K->data + K->nb[1]*first;
@ -51,6 +53,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
local_dst.op_params[4] = 0;
dst = &local_dst;
}
dst = &local_dst;
}
// On AMD the tile kernels perform poorly, use the vec kernel instead:

View File

@ -6042,13 +6042,18 @@ static ggml_cgraph * build_gemma4_graph_paralle(llm_build_context & llm, llama_c
GGML_ASSERT(cparams.flash_attn);
auto gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false);
bool is_moe = hparams.n_expert > 0;
std::vector<ggml_tensor *> sa_inp(n_device, nullptr);
std::vector<ggml_tensor *> sa_out(n_device, nullptr);
std::vector<ggml_tensor *> ffn_inp(n_device, nullptr);
std::vector<ggml_tensor *> ffn_out(n_device, nullptr);
std::vector<ggml_tensor *> ffn_out_moe;
if (is_moe) {
ffn_out_moe.resize(n_device, nullptr);
}
//ggml_tensor * last_ffn_inp = nullptr;
//ggml_tensor * last_sa_inp = nullptr;
ggml_tensor * inpL_moe = nullptr;
for (int il = 0; il < hparams.n_layer; ++il) {
auto & l = model.layers[il];
@ -6057,9 +6062,6 @@ static ggml_cgraph * build_gemma4_graph_paralle(llm_build_context & llm, llama_c
const float freq_scale_l = is_sliding ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
const int n_rot_l = is_sliding ? hparams.n_rot_swa : hparams.n_rot;
const int n_swa = is_sliding ? hparams.n_swa : 0;
//const int n_embd_head = hparams.n_embd_head_k(il);
//const int n_head = hparams.n_head(il);
//const int n_head_kv = hparams.n_head_kv(il);
struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
@ -6099,8 +6101,18 @@ static ggml_cgraph * build_gemma4_graph_paralle(llm_build_context & llm, llama_c
} else {
GGML_ASSERT(inpL->op == GGML_OP_REDUCE);
auto cur = get_input_tensor_sm_graph(ctx0, inpL, id);
if (is_moe) {
GGML_ASSERT(inpL_moe && inpL_moe->op == GGML_OP_REDUCE);
cur = do_split_norm(ctx0, cur, model.layers[il-1].ffn_post_norm_1, hparams, cb, id, il_cb, false);
cb(cur, "ffn_post_norm", il_cb);
auto cur_moe = get_input_tensor_sm_graph(ctx0, inpL_moe, id);
cur_moe = do_split_norm(ctx0, cur_moe, model.layers[il-1].ffn_post_norm_2, hparams, cb, id, il_cb, false);
cb(cur, "ffn_moe_post_norm", il_cb);
cur = ggml_add(ctx0, cur, cur_moe);
cb(cur, "ffn_combined", il_cb);
}
cur = do_split_norm(ctx0, cur, model.layers[il-1].ffn_post_norm, hparams, cb, id, il_cb, false);
cb(cur, "ffn_post_norm", il_cb);
cb(cur, "ffn_normed", il_cb);
auto add = ffn_inp[id];
if (!add) {
for (int j = 0; j < n_device; ++j) {
@ -6214,6 +6226,11 @@ static ggml_cgraph * build_gemma4_graph_paralle(llm_build_context & llm, llama_c
ggml_row_size(split_vl->type, n_embd_head_v), 0);
cb(v, "v", il_cb);
//if (il == 0 || il == 5) {
// if (il == 0 && id == 0) printf("\n");
// if (id == 0) printf("--- il = %d\n", il);
// printf("id = %d, q: %ld x %ld x %ld x %ld, k: %ld x %ld x %ld x %ld v: %ld x %ld x %ld x %ld\n", id, q->ne[0], q->ne[1], q->ne[2], q->ne[3], k->ne[0], k->ne[1], k->ne[2], k->ne[3], v->ne[0], v->ne[1], v->ne[2], v->ne[3]);
//}
cur = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask_l, hparams.f_attention_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
cb(cur, "fa", il_cb);
@ -6285,12 +6302,43 @@ static ggml_cgraph * build_gemma4_graph_paralle(llm_build_context & llm, llama_c
}
ggml_build_forward_expand(gf, cur);
ffn_out[id] = cur;
if (is_moe) {
cur = do_split_norm(ctx0, ffn_inp[id], model.layers[il].ffn_pre_norm_2, hparams, cb, id, il_cb, false);
auto tmp = ggml_rms_norm(ctx0, ffn_inp[id], hparams.f_norm_rms_eps);
tmp = ggml_scale(ctx0, tmp, 1.0f / sqrtf((float) hparams.n_embd));
tmp = ggml_mul(ctx0, tmp, ((const ggml_split_tensor_t *)model.layers[il].ffn_gate_inp_s->extra)->splits[id]);
auto logits = llm.llm_build_lora_mm(lctx, ctx0, ((const ggml_split_tensor_t *)model.layers[il].ffn_gate_inp->extra)->splits[id], tmp);
auto moe = llm. llm_build_moe_ffn(ctx0, lctx, cur,
nullptr, nullptr, nullptr,
((const ggml_split_tensor_t *)model.layers[il].ffn_down_exps->extra)->splits[id], nullptr,
llm.n_expert, llm.n_expert_used,
LLM_FFN_GELU, true, false, 0.0f,
LLM_EXPERT_GATING_FUNC_SOFTMAX,
cb, il, gf, false,
((const ggml_split_tensor_t *)model.layers[il].ffn_up_gate_exps->extra)->splits[id],
nullptr, logits, ((const ggml_split_tensor_t *)model.layers[il].ffn_down_exps_s->extra)->splits[id]);
if (moe->ne[1] > 32 && cparams.reduce_type != GGML_TYPE_F32) {
moe = ggml_cast(ctx0, moe, cparams.reduce_type);
cb(moe, "ffn_moe_cast", il_cb);
}
ggml_build_forward_expand(gf, moe);
ffn_out_moe[id] = moe;
}
}
inpL = ggml_reduce(ctx0, ffn_out.data(), n_device, GGML_OP_ADD);
cb(inpL, "ffn_reduce", il);
ggml_build_forward_expand(gf, inpL);
if (is_moe) {
inpL_moe = ggml_reduce(ctx0, ffn_out_moe.data(), n_device, GGML_OP_ADD);
cb(inpL_moe, "ffn_moe_reduce", il);
ggml_build_forward_expand(gf, inpL_moe);
}
}
int idx = lctx.model.default_layer_device[lctx.model.hparams.n_layer];
@ -6301,9 +6349,25 @@ static ggml_cgraph * build_gemma4_graph_paralle(llm_build_context & llm, llama_c
cur = inpL->view_src;
}
auto post_norm = (const ggml_split_tensor_t *)model.layers[hparams.n_layer-1].ffn_post_norm->extra;
auto post_norm = (const ggml_split_tensor_t *)model.layers[hparams.n_layer-1].ffn_post_norm->extra;
if (is_moe) {
auto cur_moe = inpL_moe->src[idx];
if (!cur_moe) {
cur_moe = inpL_moe->view_src;
}
auto post_norm_1 = (const ggml_split_tensor_t *)model.layers[hparams.n_layer-1].ffn_post_norm_1->extra;
auto post_norm_2 = (const ggml_split_tensor_t *)model.layers[hparams.n_layer-1].ffn_post_norm_2->extra;
cur = llm.llm_build_norm(ctx0, cur, hparams, post_norm_1->splits[idx], NULL, LLM_NORM_RMS, cb, -1);
cur->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff;
ggml_build_forward_expand(gf, cur);
cur_moe = llm.llm_build_norm(ctx0, cur_moe, hparams, post_norm_2->splits[idx], NULL, LLM_NORM_RMS, cb, -1);
cb(cur, "ffn_post", hparams.n_layer-1);
cb(cur_moe, "ffn_post_moe", hparams.n_layer-1);
cur = ggml_add(ctx0, cur, cur_moe);
cb(cur, "ffn_combined", hparams.n_layer-1);
}
cur = llm.llm_build_norm(ctx0, cur, hparams, post_norm->splits[idx], NULL, LLM_NORM_RMS, cb, -1);
cb(cur, "ffn_post", hparams.n_layer-1);
cb(cur, "ffn_normed", hparams.n_layer-1);
auto add = ffn_inp[idx];
if (!add) {
for (int j = 0; j < n_device; ++j) {
@ -6322,6 +6386,7 @@ static ggml_cgraph * build_gemma4_graph_paralle(llm_build_context & llm, llama_c
}
cur = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb);
cb(cur, "almost_result", -1);
if (hparams.f_final_logit_softcapping > 0) {
cur = ggml_softcap(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping, hparams.f_final_logit_softcapping);
}
@ -6405,7 +6470,7 @@ ggml_cgraph * llm_build_context::build_gemma4() {
if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH) {
return build_gemma4_graph_paralle(*this, lctx, ctx0, inpL, inp_pos, inp_out_ids,
KQ_mask, KQ_mask_swa, n_tokens, cb);
KQ_mask, KQ_mask_swa, n_tokens, cb);
}
auto gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false);

View File

@ -4057,12 +4057,6 @@ bool create_tensors_helper::create_tensors() {
if (model.tok_embd_per_layer) {
supported = false;
}
for (auto & l : model.layers) {
if (l.ffn_gate_inp) {
supported = false;
break;
}
}
if (!supported) {
LLAMA_LOG_WARN("\n=========================================================\n");
LLAMA_LOG_WARN("Split mode 'graph' is not supported for this Gemma4 variant\n");
@ -4268,6 +4262,31 @@ bool create_tensors_helper::create_tensors() {
prepare_split_tensors(-1, ctx_split, layer.ffn_post_norm, layer.split_ffn_post_norm, mirror, mem_used);
}
}
if (layer.ffn_post_norm_1) {
if (auto it = split_tensors.find(layer.ffn_post_norm_1); it != split_tensors.end()) {
prepare_split_tensors(-1, ctx_split, layer.ffn_post_norm_1, layer.split_ffn_post_norm_1, mirror, mem_used);
}
}
if (layer.ffn_post_norm_2) {
if (auto it = split_tensors.find(layer.ffn_post_norm_2); it != split_tensors.end()) {
prepare_split_tensors(-1, ctx_split, layer.ffn_post_norm_2, layer.split_ffn_post_norm_2, mirror, mem_used);
}
}
if (layer.ffn_pre_norm_2) {
if (auto it = split_tensors.find(layer.ffn_pre_norm_2); it != split_tensors.end()) {
prepare_split_tensors(-1, ctx_split, layer.ffn_pre_norm_2, layer.split_ffn_pre_norm_2, mirror, mem_used);
}
}
if (layer.ffn_down_exps_s) {
if (auto it = split_tensors.find(layer.ffn_down_exps_s); it != split_tensors.end()) {
prepare_split_tensors(-1, ctx_split, layer.ffn_down_exps_s, layer.split_ffn_down_exps_s, mirror, mem_used);
}
}
if (layer.ffn_gate_inp_s) {
if (auto it = split_tensors.find(layer.ffn_gate_inp_s); it != split_tensors.end()) {
prepare_split_tensors(-1, ctx_split, layer.ffn_gate_inp_s, layer.split_ffn_gate_inp_s, mirror, mem_used);
}
}
if (layer.ffn_down && layer.ffn_up && layer.ffn_gate) {
bool use_split = split_tensors.find(layer.ffn_down) != split_tensors.end() &&

View File

@ -259,6 +259,9 @@ struct llama_layer {
llama_split_tensor split_ffn_norm;
llama_split_tensor split_ffn_up_gate;
llama_split_tensor split_ffn_post_norm;
llama_split_tensor split_ffn_post_norm_1;
llama_split_tensor split_ffn_post_norm_2;
llama_split_tensor split_ffn_pre_norm_2;
// ff MoE
struct ggml_tensor * ffn_gate_inp = nullptr;
@ -300,6 +303,8 @@ struct llama_layer {
llama_split_tensor split_ffn_down_exps_b;
llama_split_tensor split_ffn_up_exps_b;
llama_split_tensor split_ffn_up_gate_exps_b;
llama_split_tensor split_ffn_down_exps_s;
llama_split_tensor split_ffn_gate_inp_s;
// ff bias
struct ggml_tensor * ffn_gate_b = nullptr;