From f4125e8b1f1de07ac07fbcdcca34d3a8f89f65e2 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 24 Mar 2026 07:53:16 +0100 Subject: [PATCH] Slightly better CPU performance for SWA models (#1496) --- ggml/src/iqk/fa/iqk_fa_templates.h | 53 +++++++++++++++++------------- 1 file changed, 31 insertions(+), 22 deletions(-) diff --git a/ggml/src/iqk/fa/iqk_fa_templates.h b/ggml/src/iqk/fa/iqk_fa_templates.h index 648e7308..aec79c0d 100644 --- a/ggml/src/iqk/fa/iqk_fa_templates.h +++ b/ggml/src/iqk/fa/iqk_fa_templates.h @@ -1364,12 +1364,14 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in for (; ik >=0 && Mc[ik] != 0; ik -= k_step); ik += k_step; for (int k1 = 0; k1 < ik/k_step; ++k1) { + if (k1 == ik/k_step-1 || Mc[k_step-1] == 0) { #ifdef __aarch64__ - KQHelper::multiply_mask_kq(kh, Dk, stride_m, q_f16, mr, fms); + KQHelper::multiply_mask_kq(kh, Dk, stride_m, q_f16, mr, fms); #else - KQHelper::multiply_mask_kq(kh, stride_q, stride_m, q, mr, fms); + KQHelper::multiply_mask_kq(kh, stride_q, stride_m, q, mr, fms); #endif - fqkv.accumulate_qkv(vh, fms); + fqkv.accumulate_qkv(vh, fms); + } kh.next_block(k_step); vh.next_block(k_step); mr += k_step*sizeof(ggml_half); @@ -1429,9 +1431,12 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, for (; ik >=0 && Mc[ik] != 0; ik -= k_step); ik += k_step; for (int k1 = 0; k1 < ik/k_step; ++k1) { - HelperQ80R8::repack(k_step, kh.block, kh.stride, q8r8); - KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms); - fqkv.accumulate_qkv(vh, fms); + Mc = (const uint16_t *)mr; + if (k1 == ik/k_step-1 || Mc[k_step-1] == 0) { + HelperQ80R8::repack(k_step, kh.block, kh.stride, q8r8); + KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms); + fqkv.accumulate_qkv(vh, fms); + } kh.next_block(k_step); vh.next_block(k_step); mr += k_step*sizeof(ggml_half); @@ -1460,17 +1465,20 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, for (; ik >=0 && Mc[ik] != 0; ik -= k_step); ik += k_step; for (int k1 = 0; k1 < ik/k_step; ++k1) { + Mc = (const uint16_t *)mr; + if (k1 == ik/k_step-1 || Mc[k_step-1] == 0) { #if FA_TIMING - t1 = Perf::cur_time(); - KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms); - perf.accum_nolock(1, t1); - t1 = Perf::cur_time(); - fqkv.accumulate_qkv(vh, fms); - perf.accum_nolock(2, t1); + t1 = Perf::cur_time(); + KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms); + perf.accum_nolock(1, t1); + t1 = Perf::cur_time(); + fqkv.accumulate_qkv(vh, fms); + perf.accum_nolock(2, t1); #else - KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms); - fqkv.accumulate_qkv(vh, fms); + KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms); + fqkv.accumulate_qkv(vh, fms); #endif + } kh.next_block(k_step); vh.next_block(k_step); mr += k_step*sizeof(ggml_half); @@ -1982,17 +1990,18 @@ struct FlashAttnBF16 { for (; ik >=0 && Mc[ik] != 0; ik -= k_step); ik += k_step; for (int k1 = 0; k1 < ik/k_step; ++k1) { + Mc = (const uint16_t *)mr; + if (k1 == ik/k_step-1 || Mc[k_step-1] == 0) { #if FA_TIMING - //t1 = Perf::cur_time(); - FlashQKbf16::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf); - //perf.accum_nolock(1, t1); - t1 = Perf::cur_time(); - fqkv.accumulate_qkv(vh, fms); - perf.accum_nolock(3, t1); + FlashQKbf16::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf); + t1 = Perf::cur_time(); + fqkv.accumulate_qkv(vh, fms); + perf.accum_nolock(3, t1); #else - FlashQKbf16::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms); - fqkv.accumulate_qkv(vh, fms); + FlashQKbf16::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms); + fqkv.accumulate_qkv(vh, fms); #endif + } kh.next_block(k_step); vh.next_block(k_step); mr += k_step*sizeof(ggml_half);