CUDA FA: cover Gemma4-4B/2B assistant (#1934)

This commit is contained in:
Kawrakow 2026-06-08 08:18:26 +02:00 committed by GitHub
parent b50b0919d5
commit 1660459db5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2122,7 +2122,10 @@ template <int DKQ, int DV, int ncols2>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
if constexpr ((DKQ == 576 || DKQ == 512) && ncols2 <= 4) {
if constexpr (DKQ == 512 && ncols2 == 2) {
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8, ncols2>(ctx, dst);
}
else if constexpr ((DKQ == 576 || DKQ == 512) && ncols2 <= 4) {
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 4, ncols2>(ctx, dst);
} else {
@ -2269,6 +2272,10 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
else if (gqa_ratio % 4 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<512, 512, 4>(ctx, dst);
}
else if (gqa_ratio % 2 == 0) {
// Gemma4-4B assistant
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<512, 512, 2>(ctx, dst);
}
else {
GGML_ABORT("Fatal error");
}