fa: preserve early-termination, fix multi-slot correctness via union of masks (#1880)

* fa: fix FlashQKV early-termination causing S=0 assertion with --parallel N>1

The backward-scan optimization in compute_helper/compute_helper_q checks
only one mask position per k_step block on the last query row (q_step-1)
to find where valid KV entries end. When q_step > 1 and different query
rows have non-overlapping valid KV regions (multi-slot / --parallel N>1),
the scan on the last row's mask can miss blocks that contain valid entries
for earlier rows. This causes those rows to accumulate S=0, triggering
the GGML_ASSERT(S > 0) in normalize_and_store_1row.

Fix: remove the early-termination scan at all 4 sites and iterate all
nk1/k_step blocks unconditionally. The mask already handles correctness:
fully-masked blocks produce smax=-inf and skip V accumulation, so the
performance cost is minimal for TG (small nq1) and acceptable for PP.

Fixes #809

* fa: refactor multi-slot mask fix into mask_effective_nk1() helper

Replace 4× inlined early-termination scans with a shared helper that
computes the effective K boundary by scanning ALL query mask rows
(union-of-masks). This is the minimal fix for multi-slot parallel
inference where different slots have different sequence lengths.

The helper returns the k_step-aligned boundary covering the longest
active sequence across all rows, preserving single-slot performance
(single row = same boundary as before).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Turbomen008 <Turbomen008@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Gearstickle 2026-05-26 06:16:49 -07:00 committed by GitHub
parent b4e1d916c5
commit 4fbd0c441b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -32,6 +32,24 @@
namespace {
// Compute effective K boundary by scanning ALL query rows (union-of-masks).
// The original early-termination scanned only the last row, which is correct
// for single-slot (all rows have the same mask) but wrong for multi-slot
// parallel (--parallel N>1) where different slots have different sequence
// lengths and therefore different mask patterns.
// Returns the number of K elements to process (multiple of k_step).
inline int mask_effective_nk1(const char * mask, int n_rows, int stride_m, int nk1, int k_step) {
int ik_max = 0;
for (int j = 0; j < n_rows; ++j) {
auto Mc = (const uint16_t *)(mask + j * stride_m);
int ik = nk1 - k_step;
for (; ik >= 0 && Mc[ik] != 0; ik -= k_step);
ik += k_step;
if (ik > ik_max) ik_max = ik;
}
return ik_max;
}
struct BaseHelper {
BaseHelper(const char * data, int stride) : data(data), block(data), stride(stride) {}
@ -1359,11 +1377,8 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
KQHelper::convert(q_step, stride_q, q, q_f16);
#endif
auto mr = mask;
auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
int ik = nk1 - k_step;
for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
ik += k_step;
for (int k1 = 0; k1 < ik/k_step; ++k1) {
int nk1_eff = mask_effective_nk1(mr, q_step, stride_m, nk1, k_step);
for (int k1 = 0; k1 < nk1_eff/k_step; ++k1) {
#ifdef __aarch64__
KQHelper::multiply_mask_kq(kh, Dk, stride_m, q_f16, mr, fms);
#else
@ -1424,11 +1439,8 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
auto q8r = (typename HelperQ80R8<Dk>::block_q8 *)qptr;
HelperQ80::convert<Dk>(q_step, stride_q, q, q8r);
auto mr = mask;
auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
int ik = nk1 - k_step;
for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
ik += k_step;
for (int k1 = 0; k1 < ik/k_step; ++k1) {
int nk1_eff = mask_effective_nk1(mr, q_step, stride_m, nk1, k_step);
for (int k1 = 0; k1 < nk1_eff/k_step; ++k1) {
HelperQ80R8<Dk>::repack(k_step, kh.block, kh.stride, q8r8);
KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms);
fqkv.accumulate_qkv(vh, fms);
@ -1455,11 +1467,8 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
perf.accum_nolock(0, t1);
#endif
auto mr = mask;
auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
int ik = nk1 - k_step;
for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
ik += k_step;
for (int k1 = 0; k1 < ik/k_step; ++k1) {
int nk1_eff = mask_effective_nk1(mr, q_step, stride_m, nk1, k_step);
for (int k1 = 0; k1 < nk1_eff/k_step; ++k1) {
#if FA_TIMING
t1 = Perf::cur_time();
KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms);
@ -1977,11 +1986,8 @@ struct FlashAttnBF16 {
perf.accum_nolock(0, t1);
#endif
auto mr = mask;
auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
int ik = nk1 - k_step;
for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
ik += k_step;
for (int k1 = 0; k1 < ik/k_step; ++k1) {
int nk1_eff = mask_effective_nk1(mr, q_step, stride_m, nk1, k_step);
for (int k1 = 0; k1 < nk1_eff/k_step; ++k1) {
#if FA_TIMING
//t1 = Perf::cur_time();
FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf);