mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Enable split mode graph for Gemma4-12B
This commit is contained in:
parent
19dcc1f7d1
commit
68a94ab930
@ -2147,7 +2147,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 32/ncols2) {
|
||||
if (Q->ne[1] <= 32/ncols2 || (DKQ == 512 && ncols2 == 16)) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
@ -2260,7 +2260,10 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
|
||||
// (covers 8 and 16) through the ncols2=8 kernel. It iterates over Q-head groups
|
||||
// (iter_z = ceil(gqa_ratio/ncols2)), so 16 heads run as two passes of 8. This unblocks
|
||||
// head_dim-512 models with a 16:1 GQA ratio such as Gemma 4 12B's global layers.
|
||||
if (gqa_ratio % 8 == 0) {
|
||||
if (gqa_ratio % 16 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<512, 512, 16>(ctx, dst);
|
||||
}
|
||||
else if (gqa_ratio % 8 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<512, 512, 8>(ctx, dst);
|
||||
}
|
||||
else if (gqa_ratio % 4 == 0) {
|
||||
|
||||
@ -213,6 +213,8 @@ static ggml_cgraph * build_gemma4_graph_parallel(llm_build_context & llm, llama_
|
||||
auto vl = (ggml_split_tensor_t *)kv_self.v_l[il]->extra;
|
||||
GGML_ASSERT(kl && vl);
|
||||
|
||||
int nhave = 0;
|
||||
ggml_tensor * sa_last = nullptr;
|
||||
for (int id = 0; id < n_device; ++id) {
|
||||
GGML_ASSERT((wq->splits[id] && wk->splits[id] && (!wv || wv->splits[id]) && wo->splits[id]) ||
|
||||
(!wq->splits[id] && !wk->splits[id] && (!wv || !wv->splits[id]) && !wo->splits[id]));
|
||||
@ -379,10 +381,12 @@ static ggml_cgraph * build_gemma4_graph_parallel(llm_build_context & llm, llama_
|
||||
}
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
sa_out[id] = cur;
|
||||
sa_last = cur;
|
||||
++nhave;
|
||||
|
||||
}
|
||||
|
||||
auto last_ffn_inp = ggml_reduce(ctx0, sa_out.data(), n_device, GGML_OP_ADD);
|
||||
auto last_ffn_inp = nhave > 1 ? ggml_reduce(ctx0, sa_out.data(), n_device, GGML_OP_ADD) : sa_last;
|
||||
ggml_build_forward_expand(gf, last_ffn_inp);
|
||||
cb(last_ffn_inp, "sa_reduce", il);
|
||||
|
||||
@ -403,7 +407,7 @@ static ggml_cgraph * build_gemma4_graph_parallel(llm_build_context & llm, llama_
|
||||
}
|
||||
int il_cb = 1000*(il + 1) + id;
|
||||
|
||||
GGML_ASSERT(last_ffn_inp && last_ffn_inp->op == GGML_OP_REDUCE);
|
||||
GGML_ASSERT(last_ffn_inp && (nhave == 1 || last_ffn_inp->op == GGML_OP_REDUCE));
|
||||
auto cur = llm_build_context::get_input_tensor_sm_graph(ctx0, last_ffn_inp, id);
|
||||
cur = llm_build_context::do_split_norm(ctx0, cur, model.layers[il].attn_post_norm, hparams, cb, id, il_cb, false);
|
||||
cb(cur, "sa_post", il_cb);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user