diff --git a/ggml/src/iqk/fa/iqk_fa_templates.h b/ggml/src/iqk/fa/iqk_fa_templates.h index 3aaff08c..711ae7c0 100644 --- a/ggml/src/iqk/fa/iqk_fa_templates.h +++ b/ggml/src/iqk/fa/iqk_fa_templates.h @@ -32,6 +32,24 @@ namespace { +// Compute effective K boundary by scanning ALL query rows (union-of-masks). +// The original early-termination scanned only the last row, which is correct +// for single-slot (all rows have the same mask) but wrong for multi-slot +// parallel (--parallel N>1) where different slots have different sequence +// lengths and therefore different mask patterns. +// Returns the number of K elements to process (multiple of k_step). +inline int mask_effective_nk1(const char * mask, int n_rows, int stride_m, int nk1, int k_step) { + int ik_max = 0; + for (int j = 0; j < n_rows; ++j) { + auto Mc = (const uint16_t *)(mask + j * stride_m); + int ik = nk1 - k_step; + for (; ik >= 0 && Mc[ik] != 0; ik -= k_step); + ik += k_step; + if (ik > ik_max) ik_max = ik; + } + return ik_max; +} + struct BaseHelper { BaseHelper(const char * data, int stride) : data(data), block(data), stride(stride) {} @@ -1359,11 +1377,8 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in KQHelper::convert(q_step, stride_q, q, q_f16); #endif auto mr = mask; - auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m); - int ik = nk1 - k_step; - for (; ik >=0 && Mc[ik] != 0; ik -= k_step); - ik += k_step; - for (int k1 = 0; k1 < ik/k_step; ++k1) { + int nk1_eff = mask_effective_nk1(mr, q_step, stride_m, nk1, k_step); + for (int k1 = 0; k1 < nk1_eff/k_step; ++k1) { #ifdef __aarch64__ KQHelper::multiply_mask_kq(kh, Dk, stride_m, q_f16, mr, fms); #else @@ -1424,11 +1439,8 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, auto q8r = (typename HelperQ80R8::block_q8 *)qptr; HelperQ80::convert(q_step, stride_q, q, q8r); auto mr = mask; - auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m); - int ik = nk1 - k_step; - for (; ik >=0 && Mc[ik] != 0; ik -= k_step); - ik += k_step; - for (int k1 = 0; k1 < ik/k_step; ++k1) { + int nk1_eff = mask_effective_nk1(mr, q_step, stride_m, nk1, k_step); + for (int k1 = 0; k1 < nk1_eff/k_step; ++k1) { HelperQ80R8::repack(k_step, kh.block, kh.stride, q8r8); KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms); fqkv.accumulate_qkv(vh, fms); @@ -1455,11 +1467,8 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, perf.accum_nolock(0, t1); #endif auto mr = mask; - auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m); - int ik = nk1 - k_step; - for (; ik >=0 && Mc[ik] != 0; ik -= k_step); - ik += k_step; - for (int k1 = 0; k1 < ik/k_step; ++k1) { + int nk1_eff = mask_effective_nk1(mr, q_step, stride_m, nk1, k_step); + for (int k1 = 0; k1 < nk1_eff/k_step; ++k1) { #if FA_TIMING t1 = Perf::cur_time(); KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms); @@ -1977,11 +1986,8 @@ struct FlashAttnBF16 { perf.accum_nolock(0, t1); #endif auto mr = mask; - auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m); - int ik = nk1 - k_step; - for (; ik >=0 && Mc[ik] != 0; ik -= k_step); - ik += k_step; - for (int k1 = 0; k1 < ik/k_step; ++k1) { + int nk1_eff = mask_effective_nk1(mr, q_step, stride_m, nk1, k_step); + for (int k1 = 0; k1 < nk1_eff/k_step; ++k1) { #if FA_TIMING //t1 = Perf::cur_time(); FlashQKbf16::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf);