Slightly better CPU performance for SWA models (#1496)

This commit is contained in:
Kawrakow 2026-03-24 07:53:16 +01:00 committed by GitHub
parent bbc07002f7
commit f4125e8b1f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1364,12 +1364,14 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
for (; ik >=0 && Mc[ik] != 0; ik -= k_step); for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
ik += k_step; ik += k_step;
for (int k1 = 0; k1 < ik/k_step; ++k1) { for (int k1 = 0; k1 < ik/k_step; ++k1) {
if (k1 == ik/k_step-1 || Mc[k_step-1] == 0) {
#ifdef __aarch64__ #ifdef __aarch64__
KQHelper::multiply_mask_kq(kh, Dk, stride_m, q_f16, mr, fms); KQHelper::multiply_mask_kq(kh, Dk, stride_m, q_f16, mr, fms);
#else #else
KQHelper::multiply_mask_kq(kh, stride_q, stride_m, q, mr, fms); KQHelper::multiply_mask_kq(kh, stride_q, stride_m, q, mr, fms);
#endif #endif
fqkv.accumulate_qkv(vh, fms); fqkv.accumulate_qkv(vh, fms);
}
kh.next_block(k_step); kh.next_block(k_step);
vh.next_block(k_step); vh.next_block(k_step);
mr += k_step*sizeof(ggml_half); mr += k_step*sizeof(ggml_half);
@@ -1429,9 +1431,12 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
for (; ik >=0 && Mc[ik] != 0; ik -= k_step); for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
ik += k_step; ik += k_step;
for (int k1 = 0; k1 < ik/k_step; ++k1) { for (int k1 = 0; k1 < ik/k_step; ++k1) {
HelperQ80R8<Dk>::repack(k_step, kh.block, kh.stride, q8r8); Mc = (const uint16_t *)mr;
KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms); if (k1 == ik/k_step-1 || Mc[k_step-1] == 0) {
fqkv.accumulate_qkv(vh, fms); HelperQ80R8<Dk>::repack(k_step, kh.block, kh.stride, q8r8);
KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms);
fqkv.accumulate_qkv(vh, fms);
}
kh.next_block(k_step); kh.next_block(k_step);
vh.next_block(k_step); vh.next_block(k_step);
mr += k_step*sizeof(ggml_half); mr += k_step*sizeof(ggml_half);
@@ -1460,17 +1465,20 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
for (; ik >=0 && Mc[ik] != 0; ik -= k_step); for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
ik += k_step; ik += k_step;
for (int k1 = 0; k1 < ik/k_step; ++k1) { for (int k1 = 0; k1 < ik/k_step; ++k1) {
Mc = (const uint16_t *)mr;
if (k1 == ik/k_step-1 || Mc[k_step-1] == 0) {
#if FA_TIMING #if FA_TIMING
t1 = Perf::cur_time(); t1 = Perf::cur_time();
KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms); KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms);
perf.accum_nolock(1, t1); perf.accum_nolock(1, t1);
t1 = Perf::cur_time(); t1 = Perf::cur_time();
fqkv.accumulate_qkv(vh, fms); fqkv.accumulate_qkv(vh, fms);
perf.accum_nolock(2, t1); perf.accum_nolock(2, t1);
#else #else
KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms); KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms);
fqkv.accumulate_qkv(vh, fms); fqkv.accumulate_qkv(vh, fms);
#endif #endif
}
kh.next_block(k_step); kh.next_block(k_step);
vh.next_block(k_step); vh.next_block(k_step);
mr += k_step*sizeof(ggml_half); mr += k_step*sizeof(ggml_half);
@@ -1982,17 +1990,18 @@ struct FlashAttnBF16 {
for (; ik >=0 && Mc[ik] != 0; ik -= k_step); for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
ik += k_step; ik += k_step;
for (int k1 = 0; k1 < ik/k_step; ++k1) { for (int k1 = 0; k1 < ik/k_step; ++k1) {
Mc = (const uint16_t *)mr;
if (k1 == ik/k_step-1 || Mc[k_step-1] == 0) {
#if FA_TIMING #if FA_TIMING
//t1 = Perf::cur_time(); FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf);
FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf); t1 = Perf::cur_time();
//perf.accum_nolock(1, t1); fqkv.accumulate_qkv(vh, fms);
t1 = Perf::cur_time(); perf.accum_nolock(3, t1);
fqkv.accumulate_qkv(vh, fms);
perf.accum_nolock(3, t1);
#else #else
FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms); FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms);
fqkv.accumulate_qkv(vh, fms); fqkv.accumulate_qkv(vh, fms);
#endif #endif
}
kh.next_block(k_step); kh.next_block(k_step);
vh.next_block(k_step); vh.next_block(k_step);
mr += k_step*sizeof(ggml_half); mr += k_step*sizeof(ggml_half);