cuda : fix MLA flash-attn vec decode for asymmetric K/V head sizes (#2031)

The flash-attn vec kernels walk the KV cache in blocks of Dk rows for the
score loop but accumulate V in blocks of Dv. With Dk == Dv that is the same
thing, so normal attention shapes are fine. For absorbed MLA shapes where the
K and V head sizes differ (Dk=576/Dv=512 and Dk=192/Dv=128) the two loops step
a different number of KV rows, so K and V drift out of sync after the first
block and the V pointer reads the wrong cache rows.

This only shows up at decode (batch=1) on cards that fall back to the vec
kernel for MLA, which on NVIDIA is pre-Volta. There deepseek2/GLM MLA models
with -mla 1 -fa 1 or -mla 3 -fa 1 decode coherently for short prompts but
collapse into garbage once n_kv passes the first KV block (Dk=576). Prefill/PPL
is unaffected because prefill takes the tile kernel, not the vec kernel.

Fix: the score loop already covers Dk KV rows, so the V loop and the V pointer
step Dk rows too. For asymmetric Dk>Dv the V row is only Dv wide, so threads
with tid >= Dv have no V element (their VKQ lane is discarded at the output
store anyway) and read 0 instead of stepping past the row.

The change keys off the compile-time Dk != Dv, so every symmetric instantiation
compiles to byte-identical code and modern GPUs (which never take this vec path
for MLA) are unaffected.

Validated on a Tesla P100 (sm_60) with DeepSeek-V2-Lite Q4_K_M: decode coherence
restored for -mla 1/3 -fa 1, KLD vs the -fa 0 soft_max path drops from 4.79 to
1.4e-4 (same top token 27% -> 100%) at c1024, and TG is unchanged (82.8 t/s).

Co-authored-by: mb8565 <244351746+mb8565@users.noreply.github.com>
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
mb8565 2026-06-25 01:56:17 -05:00 committed by GitHub
parent d5507e33ae
commit 4553cd0059
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 26 additions and 12 deletions

View File

@ -195,11 +195,11 @@ static __global__ void flash_attn_vec_ext_f16(
// printf("gridDims = %u, %u, %u, ncols = %d, head = %d, blockIdx.x = %d, blockIdx.y = %d, bounds = %d, %d, ne11 = %d, nb11 = %d, blockIdx.y*Dk = %d\n", gridDim.x, gridDim.y, gridDim.z, ncols, head, blockIdx.x, blockIdx.y, KV_min_max[sequence*gridDim.x + blockIdx.x].x, KV_min_max[sequence*gridDim.x + blockIdx.x].y, ne11, nb11, blockIdx.y*Dk);
//}
K += (first_y + blockIdx.y*Dk) * nb11;
V += (first_y + blockIdx.y*Dv) * nb21;
V += (first_y + blockIdx.y*Dk) * nb21;
maskh += (first_y + blockIdx.y*Dk);
for (int k_VKQ_0 = first_y + blockIdx.y*Dk; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*Dk,
// Increment pointers after each loop:
K += gridDim.y*Dk*nb11, V += gridDim.y*Dv*nb21, maskh += gridDim.y*Dk) {
K += gridDim.y*Dk*nb11, V += gridDim.y*Dk*nb21, maskh += gridDim.y*Dk) {
// For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
// see https://github.com/ggerganov/llama.cpp/pull/7061 .
@ -278,15 +278,23 @@ static __global__ void flash_attn_vec_ext_f16(
__syncthreads();
// Accumulate V over all Dk scored KV positions of this block (not Dv). For asymmetric MLA
// head sizes (Dk=576 K, Dv=512 V) the K/score loop above covers Dk KV rows, so the V loop and
// the V pointer stride must also step Dk KV rows or K and V desync (= garbage decode on sm_60,
// which uses this vec kernel for MLA -fa 1 batch=1). Dk==Dv leaves every other case unchanged.
#pragma unroll
for (int k0 = 0; k0 < Dv; k0 += 2) {
if (FATTN_KQ_STRIDE % Dv != 0 && k_VKQ_0 + k0 >= ne11) {
for (int k0 = 0; k0 < Dk; k0 += 2) {
if (FATTN_KQ_STRIDE % Dk != 0 && k_VKQ_0 + k0 >= ne11) {
break;
}
half2 V_k;
reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid);
reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid);
// For asymmetric Dk>Dv the V row is only Dv wide, so threads tid>=Dv have no V element
// (their VKQ lane is discarded at output anyway). Read 0 to avoid stepping past the row.
half2 V_k = make_half2(0.0f, 0.0f);
if (tid < Dv) {
reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid);
reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid);
}
#pragma unroll
for (int j = 0; j < ncols; ++j) {
VKQ[j] += V_k*KQ2[j*(Dk/2) + k0/2];

View File

@ -189,11 +189,11 @@ static __global__ void flash_attn_vec_ext_f32(
const int first_y = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].x : 0;
K += (first_y + blockIdx.y*Dk) * nb11;
V += (first_y + blockIdx.y*Dv) * nb21;
V += (first_y + blockIdx.y*Dk) * nb21;
maskh += (first_y + blockIdx.y*Dk);
for (int k_VKQ_0 = first_y + blockIdx.y*Dk; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*Dk,
// Increment pointers after each loop:
K += gridDim.y*Dk*nb11, V += gridDim.y*Dv*nb21, maskh += gridDim.y*Dk) {
K += gridDim.y*Dk*nb11, V += gridDim.y*Dk*nb21, maskh += gridDim.y*Dk) {
// Calculate KQ tile and keep track of new maximum KQ values:
@ -266,13 +266,19 @@ static __global__ void flash_attn_vec_ext_f32(
__syncthreads();
// Accumulate V over all Dk scored KV positions of this block (not Dv). For asymmetric MLA
// head sizes (Dk=576 K, Dv=512 V) the K/score loop above covers Dk KV rows, so the V loop and
// the V pointer stride must also step Dk KV rows or K and V desync (= garbage decode on sm_60,
// which uses this vec kernel for MLA -fa 1 batch=1). Dk==Dv leaves every other case unchanged.
#pragma unroll
for (int k = 0; k < Dv; ++k) {
if (FATTN_KQ_STRIDE % Dv != 0 && k_VKQ_0 + k >= ne11) {
for (int k = 0; k < Dk; ++k) {
if (FATTN_KQ_STRIDE % Dk != 0 && k_VKQ_0 + k >= ne11) {
break;
}
const float V_ki = dequantize_1_v(V + k*nb21, tid);
// For asymmetric Dk>Dv the V row is only Dv wide, so threads tid>=Dv have no V element
// (their VKQ lane is discarded at output anyway). Read 0 to avoid stepping past the row.
const float V_ki = tid < Dv ? dequantize_1_v(V + k*nb21, tid) : 0.0f;
#pragma unroll
for (int j = 0; j < ncols; ++j) {
VKQ[j] += V_ki*KQ[j*Dk + k];