From 6be3a488d32229f0b07e3034abf43ee1d7dae097 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Mon, 15 Jun 2026 11:46:02 +0000 Subject: [PATCH] CUDA FA: faster TG when GQA is 16 and head size is 128 --- ggml/src/ggml-cuda/fattn-mma-f16.cu | 13 +++++++++++++ .../fattn-mma-f16-instance-ncols1_1-ncols2_8.cu | 2 ++ .../fattn-mma-f16-instance-ncols1_2-ncols2_8.cu | 2 ++ .../fattn-mma-f16-instance-ncols1_4-ncols2_8.cu | 2 ++ 4 files changed, 19 insertions(+) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cu b/ggml/src/ggml-cuda/fattn-mma-f16.cu index 62c52132..50003e34 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cu @@ -223,6 +223,19 @@ void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tens return; } + if (use_gqa_opt && gqa_ratio % 16 == 0 && Q->ne[0] == 128 && Q->ne[0] <= 8) { + if (Q->ne[1] <= 1) { + ggml_cuda_flash_attn_ext_mma_f16_case<128, 1, 16>(ctx, dst); + } + else if (Q->ne[1] <= 2) { + ggml_cuda_flash_attn_ext_mma_f16_case<128, 2, 16>(ctx, dst); + } + else { + ggml_cuda_flash_attn_ext_mma_f16_case<128, 4, 16>(ctx, dst); + } + return; + } + if (use_gqa_opt && gqa_ratio % 8 == 0) { ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst); return; diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu index 80108615..e0707acd 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu @@ -8,3 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 1, 8); DECL_FATTN_MMA_F16_CASE(112, 1, 8); DECL_FATTN_MMA_F16_CASE(128, 1, 8); DECL_FATTN_MMA_F16_CASE(256, 1, 8); + +DECL_FATTN_MMA_F16_CASE(128, 1, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu index 617464c9..7585b49f 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu @@ -8,3 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 2, 8); DECL_FATTN_MMA_F16_CASE(112, 2, 8); DECL_FATTN_MMA_F16_CASE(128, 2, 8); DECL_FATTN_MMA_F16_CASE(256, 2, 8); + +DECL_FATTN_MMA_F16_CASE(128, 2, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu index e8cb0e1b..756db077 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu @@ -8,3 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 4, 8); DECL_FATTN_MMA_F16_CASE(112, 4, 8); DECL_FATTN_MMA_F16_CASE(128, 4, 8); DECL_FATTN_MMA_F16_CASE(256, 4, 8); + +DECL_FATTN_MMA_F16_CASE(128, 4, 16);