From 68a94ab93003ecdf2108cdc285feff82a68fd6a0 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 4 Jun 2026 16:12:53 +0000 Subject: [PATCH] Enable split mode graph for Gemma4-12B --- ggml/src/ggml-cuda/fattn-new-mma.cu | 7 +++++-- src/graphs/build_gemma4.cpp | 8 ++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu index 8c60b563..0ab38d77 100644 --- a/ggml/src/ggml-cuda/fattn-new-mma.cu +++ b/ggml/src/ggml-cuda/fattn-new-mma.cu @@ -2147,7 +2147,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con return; } - if (Q->ne[1] <= 32/ncols2) { + if (Q->ne[1] <= 32/ncols2 || (DKQ == 512 && ncols2 == 16)) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } @@ -2260,7 +2260,10 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens // (covers 8 and 16) through the ncols2=8 kernel. It iterates over Q-head groups // (iter_z = ceil(gqa_ratio/ncols2)), so 16 heads run as two passes of 8. This unblocks // head_dim-512 models with a 16:1 GQA ratio such as Gemma 4 12B's global layers. - if (gqa_ratio % 8 == 0) { + if (gqa_ratio % 16 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<512, 512, 16>(ctx, dst); + } + else if (gqa_ratio % 8 == 0) { ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<512, 512, 8>(ctx, dst); } else if (gqa_ratio % 4 == 0) { diff --git a/src/graphs/build_gemma4.cpp b/src/graphs/build_gemma4.cpp index 8a2c8029..64d0a28d 100644 --- a/src/graphs/build_gemma4.cpp +++ b/src/graphs/build_gemma4.cpp @@ -213,6 +213,8 @@ static ggml_cgraph * build_gemma4_graph_parallel(llm_build_context & llm, llama_ auto vl = (ggml_split_tensor_t *)kv_self.v_l[il]->extra; GGML_ASSERT(kl && vl); + int nhave = 0; + ggml_tensor * sa_last = nullptr; for (int id = 0; id < n_device; ++id) { GGML_ASSERT((wq->splits[id] && wk->splits[id] && (!wv || wv->splits[id]) && wo->splits[id]) || (!wq->splits[id] && !wk->splits[id] && (!wv || !wv->splits[id]) && !wo->splits[id])); @@ -379,10 +381,12 @@ static ggml_cgraph * build_gemma4_graph_parallel(llm_build_context & llm, llama_ } ggml_build_forward_expand(gf, cur); sa_out[id] = cur; + sa_last = cur; + ++nhave; } - auto last_ffn_inp = ggml_reduce(ctx0, sa_out.data(), n_device, GGML_OP_ADD); + auto last_ffn_inp = nhave > 1 ? ggml_reduce(ctx0, sa_out.data(), n_device, GGML_OP_ADD) : sa_last; ggml_build_forward_expand(gf, last_ffn_inp); cb(last_ffn_inp, "sa_reduce", il); @@ -403,7 +407,7 @@ static ggml_cgraph * build_gemma4_graph_parallel(llm_build_context & llm, llama_ } int il_cb = 1000*(il + 1) + id; - GGML_ASSERT(last_ffn_inp && last_ffn_inp->op == GGML_OP_REDUCE); + GGML_ASSERT(last_ffn_inp && (nhave == 1 || last_ffn_inp->op == GGML_OP_REDUCE)); auto cur = llm_build_context::get_input_tensor_sm_graph(ctx0, last_ffn_inp, id); cur = llm_build_context::do_split_norm(ctx0, cur, model.layers[il].attn_post_norm, hparams, cb, id, il_cb, false); cb(cur, "sa_post", il_cb);