diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu index 0ab38d77..a9a83a3e 100644 --- a/ggml/src/ggml-cuda/fattn-new-mma.cu +++ b/ggml/src/ggml-cuda/fattn-new-mma.cu @@ -2122,7 +2122,10 @@ template static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; - if constexpr ((DKQ == 576 || DKQ == 512) && ncols2 <= 4) { + if constexpr (DKQ == 512 && ncols2 == 2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + } + else if constexpr ((DKQ == 576 || DKQ == 512) && ncols2 <= 4) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); } else { @@ -2269,6 +2272,10 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens else if (gqa_ratio % 4 == 0) { ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<512, 512, 4>(ctx, dst); } + else if (gqa_ratio % 2 == 0) { + // Gemma4-4B assistant + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<512, 512, 2>(ctx, dst); + } else { GGML_ABORT("Fatal error"); }