mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Merge pull request #1973 from ikawrakow/ik/fattn_mma_gqa_16
CUDA FA: faster TG when GQA is 16 and head size is 128
This commit is contained in:
commit
2f524850a1
@ -223,6 +223,19 @@ void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tens
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (use_gqa_opt && gqa_ratio % 16 == 0 && Q->ne[0] == 128 && Q->ne[0] <= 8) {
|
||||||
|
if (Q->ne[1] <= 1) {
|
||||||
|
ggml_cuda_flash_attn_ext_mma_f16_case<128, 1, 16>(ctx, dst);
|
||||||
|
}
|
||||||
|
else if (Q->ne[1] <= 2) {
|
||||||
|
ggml_cuda_flash_attn_ext_mma_f16_case<128, 2, 16>(ctx, dst);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
ggml_cuda_flash_attn_ext_mma_f16_case<128, 4, 16>(ctx, dst);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
||||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
|
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
|
||||||
return;
|
return;
|
||||||
|
|||||||
@ -8,3 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 1, 8);
|
|||||||
DECL_FATTN_MMA_F16_CASE(112, 1, 8);
|
DECL_FATTN_MMA_F16_CASE(112, 1, 8);
|
||||||
DECL_FATTN_MMA_F16_CASE(128, 1, 8);
|
DECL_FATTN_MMA_F16_CASE(128, 1, 8);
|
||||||
DECL_FATTN_MMA_F16_CASE(256, 1, 8);
|
DECL_FATTN_MMA_F16_CASE(256, 1, 8);
|
||||||
|
|
||||||
|
DECL_FATTN_MMA_F16_CASE(128, 1, 16);
|
||||||
|
|||||||
@ -8,3 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 2, 8);
|
|||||||
DECL_FATTN_MMA_F16_CASE(112, 2, 8);
|
DECL_FATTN_MMA_F16_CASE(112, 2, 8);
|
||||||
DECL_FATTN_MMA_F16_CASE(128, 2, 8);
|
DECL_FATTN_MMA_F16_CASE(128, 2, 8);
|
||||||
DECL_FATTN_MMA_F16_CASE(256, 2, 8);
|
DECL_FATTN_MMA_F16_CASE(256, 2, 8);
|
||||||
|
|
||||||
|
DECL_FATTN_MMA_F16_CASE(128, 2, 16);
|
||||||
|
|||||||
@ -8,3 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 4, 8);
|
|||||||
DECL_FATTN_MMA_F16_CASE(112, 4, 8);
|
DECL_FATTN_MMA_F16_CASE(112, 4, 8);
|
||||||
DECL_FATTN_MMA_F16_CASE(128, 4, 8);
|
DECL_FATTN_MMA_F16_CASE(128, 4, 8);
|
||||||
DECL_FATTN_MMA_F16_CASE(256, 4, 8);
|
DECL_FATTN_MMA_F16_CASE(256, 4, 8);
|
||||||
|
|
||||||
|
DECL_FATTN_MMA_F16_CASE(128, 4, 16);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user