mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
fa: preserve early-termination, fix multi-slot correctness via union of masks (#1880)
* fa: fix FlashQKV early-termination causing S=0 assertion with --parallel N>1 The backward-scan optimization in compute_helper/compute_helper_q checks only one mask position per k_step block on the last query row (q_step-1) to find where valid KV entries end. When q_step > 1 and different query rows have non-overlapping valid KV regions (multi-slot / --parallel N>1), the scan on the last row's mask can miss blocks that contain valid entries for earlier rows. This causes those rows to accumulate S=0, triggering the GGML_ASSERT(S > 0) in normalize_and_store_1row. Fix: remove the early-termination scan at all 4 sites and iterate all nk1/k_step blocks unconditionally. The mask already handles correctness: fully-masked blocks produce smax=-inf and skip V accumulation, so the performance cost is minimal for TG (small nq1) and acceptable for PP. Fixes #809 * fa: refactor multi-slot mask fix into mask_effective_nk1() helper Replace 4× inlined early-termination scans with a shared helper that computes the effective K boundary by scanning ALL query mask rows (union-of-masks). This is the minimal fix for multi-slot parallel inference where different slots have different sequence lengths. The helper returns the k_step-aligned boundary covering the longest active sequence across all rows, preserving single-slot performance (single row = same boundary as before). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Turbomen008 <Turbomen008@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
b4e1d916c5
commit
4fbd0c441b
@ -32,6 +32,24 @@
|
||||
|
||||
namespace {
|
||||
|
||||
// Compute the effective K boundary as the union over ALL query rows of the mask.
//
// The original early-termination inspected only the last query row. That is
// correct for single-slot inference (all rows share the same mask tail) but
// wrong for multi-slot parallel inference (--parallel N>1), where different
// slots have different sequence lengths and therefore different mask patterns.
//
// Parameters:
//   mask     - start of the first query row of the mask; entries are read as
//              16-bit values and a non-zero entry is treated as "masked out"
//              (presumably fp16 -inf; confirm against the mask producer).
//   n_rows   - number of query rows to inspect.
//   stride_m - distance between consecutive mask rows, in bytes.
//   nk1      - number of K positions covered by each mask row.
//   k_step   - K block size; the scan moves in steps of k_step.
//
// Returns the number of K elements to process: the largest per-row boundary
// found, which is a multiple of k_step whenever nk1 is.
//
// NOTE(review): only the first entry of every k_step block is sampled, exactly
// as the per-site scans this helper replaces did. This assumes a block whose
// first entry is masked contains no valid entries - confirm for all mask
// layouts that reach this code.
inline int mask_effective_nk1(const char * mask, int n_rows, int stride_m, int nk1, int k_step) {
    int ik_max = 0;
    // Once the boundary reaches nk1 no further row can extend it, so stop early.
    for (int j = 0; j < n_rows && ik_max < nk1; ++j) {
        auto Mc = reinterpret_cast<const uint16_t *>(mask + j * stride_m);
        int ik = nk1 - k_step;
        // Walk backwards over masked blocks. Blocks below the boundary already
        // established by earlier rows cannot change the result, so the scan is
        // bounded by ik_max rather than by 0.
        while (ik >= ik_max && Mc[ik] != 0) {
            ik -= k_step;
        }
        ik += k_step; // one block past the last block that has to be processed
        if (ik > ik_max) {
            ik_max = ik;
        }
    }
    return ik_max;
}
|
||||
|
||||
struct BaseHelper {
|
||||
BaseHelper(const char * data, int stride) : data(data), block(data), stride(stride) {}
|
||||
|
||||
@ -1359,11 +1377,8 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
|
||||
KQHelper::convert(q_step, stride_q, q, q_f16);
|
||||
#endif
|
||||
auto mr = mask;
|
||||
auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
|
||||
int ik = nk1 - k_step;
|
||||
for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
|
||||
ik += k_step;
|
||||
for (int k1 = 0; k1 < ik/k_step; ++k1) {
|
||||
int nk1_eff = mask_effective_nk1(mr, q_step, stride_m, nk1, k_step);
|
||||
for (int k1 = 0; k1 < nk1_eff/k_step; ++k1) {
|
||||
#ifdef __aarch64__
|
||||
KQHelper::multiply_mask_kq(kh, Dk, stride_m, q_f16, mr, fms);
|
||||
#else
|
||||
@ -1424,11 +1439,8 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
|
||||
auto q8r = (typename HelperQ80R8<Dk>::block_q8 *)qptr;
|
||||
HelperQ80::convert<Dk>(q_step, stride_q, q, q8r);
|
||||
auto mr = mask;
|
||||
auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
|
||||
int ik = nk1 - k_step;
|
||||
for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
|
||||
ik += k_step;
|
||||
for (int k1 = 0; k1 < ik/k_step; ++k1) {
|
||||
int nk1_eff = mask_effective_nk1(mr, q_step, stride_m, nk1, k_step);
|
||||
for (int k1 = 0; k1 < nk1_eff/k_step; ++k1) {
|
||||
HelperQ80R8<Dk>::repack(k_step, kh.block, kh.stride, q8r8);
|
||||
KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms);
|
||||
fqkv.accumulate_qkv(vh, fms);
|
||||
@ -1455,11 +1467,8 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
|
||||
perf.accum_nolock(0, t1);
|
||||
#endif
|
||||
auto mr = mask;
|
||||
auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
|
||||
int ik = nk1 - k_step;
|
||||
for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
|
||||
ik += k_step;
|
||||
for (int k1 = 0; k1 < ik/k_step; ++k1) {
|
||||
int nk1_eff = mask_effective_nk1(mr, q_step, stride_m, nk1, k_step);
|
||||
for (int k1 = 0; k1 < nk1_eff/k_step; ++k1) {
|
||||
#if FA_TIMING
|
||||
t1 = Perf::cur_time();
|
||||
KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms);
|
||||
@ -1977,11 +1986,8 @@ struct FlashAttnBF16 {
|
||||
perf.accum_nolock(0, t1);
|
||||
#endif
|
||||
auto mr = mask;
|
||||
auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
|
||||
int ik = nk1 - k_step;
|
||||
for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
|
||||
ik += k_step;
|
||||
for (int k1 = 0; k1 < ik/k_step; ++k1) {
|
||||
int nk1_eff = mask_effective_nk1(mr, q_step, stride_m, nk1, k_step);
|
||||
for (int k1 = 0; k1 < nk1_eff/k_step; ++k1) {
|
||||
#if FA_TIMING
|
||||
//t1 = Perf::cur_time();
|
||||
FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user