Enable split mode graph for Gemma4-12B

This commit is contained in:
Kawrakow 2026-06-04 16:12:53 +00:00
parent 19dcc1f7d1
commit 68a94ab930
2 changed files with 11 additions and 4 deletions

View File

@@ -2147,7 +2147,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
return;
}
if (Q->ne[1] <= 32/ncols2) {
if (Q->ne[1] <= 32/ncols2 || (DKQ == 512 && ncols2 == 16)) {
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
return;
}
@@ -2260,7 +2260,10 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
// (covers 8 and 16) through the ncols2=8 kernel. It iterates over Q-head groups
// (iter_z = ceil(gqa_ratio/ncols2)), so 16 heads run as two passes of 8. This unblocks
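// head_dim-512 models with a 16:1 GQA ratio such as Gemma 4 12B's global layers.
// NOTE: ratios divisible by 16 are dispatched to the ncols2=16 kernel by the first branch
// below, so the ncols2=8 path described above now serves only multiples of 8 that are
// not multiples of 16.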
if (gqa_ratio % 8 == 0) {
if (gqa_ratio % 16 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<512, 512, 16>(ctx, dst);
}
else if (gqa_ratio % 8 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<512, 512, 8>(ctx, dst);
}
else if (gqa_ratio % 4 == 0) {

View File

@@ -213,6 +213,8 @@ static ggml_cgraph * build_gemma4_graph_parallel(llm_build_context & llm, llama_
auto vl = (ggml_split_tensor_t *)kv_self.v_l[il]->extra;
GGML_ASSERT(kl && vl);
int nhave = 0;
ggml_tensor * sa_last = nullptr;
for (int id = 0; id < n_device; ++id) {
GGML_ASSERT((wq->splits[id] && wk->splits[id] && (!wv || wv->splits[id]) && wo->splits[id]) ||
(!wq->splits[id] && !wk->splits[id] && (!wv || !wv->splits[id]) && !wo->splits[id]));
@@ -379,10 +381,12 @@ static ggml_cgraph * build_gemma4_graph_parallel(llm_build_context & llm, llama_
}
ggml_build_forward_expand(gf, cur);
sa_out[id] = cur;
sa_last = cur;
++nhave;
}
auto last_ffn_inp = ggml_reduce(ctx0, sa_out.data(), n_device, GGML_OP_ADD);
auto last_ffn_inp = nhave > 1 ? ggml_reduce(ctx0, sa_out.data(), n_device, GGML_OP_ADD) : sa_last;
ggml_build_forward_expand(gf, last_ffn_inp);
cb(last_ffn_inp, "sa_reduce", il);
@@ -403,7 +407,7 @@ static ggml_cgraph * build_gemma4_graph_parallel(llm_build_context & llm, llama_
}
int il_cb = 1000*(il + 1) + id;
GGML_ASSERT(last_ffn_inp && last_ffn_inp->op == GGML_OP_REDUCE);
GGML_ASSERT(last_ffn_inp && (nhave == 1 || last_ffn_inp->op == GGML_OP_REDUCE));
auto cur = llm_build_context::get_input_tensor_sm_graph(ctx0, last_ffn_inp, id);
cur = llm_build_context::do_split_norm(ctx0, cur, model.layers[il].attn_post_norm, hparams, cb, id, il_cb, false);
cb(cur, "sa_post", il_cb);