mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
cuda : fix MLA flash-attn vec decode for asymmetric K/V head sizes (#2031)
The flash-attn vec kernels walk the KV cache in blocks of Dk rows for the score loop but accumulate V in blocks of Dv. With Dk == Dv that is the same thing, so normal attention shapes are fine. For absorbed MLA shapes where the K and V head sizes differ (Dk=576/Dv=512 and Dk=192/Dv=128) the two loops step a different number of KV rows, so K and V drift out of sync after the first block and the V pointer reads the wrong cache rows. This only shows up at decode (batch=1) on cards that fall back to the vec kernel for MLA, which on NVIDIA is pre-Volta. There deepseek2/GLM MLA models with -mla 1 -fa 1 or -mla 3 -fa 1 decode coherently for short prompts but collapse into garbage once n_kv passes the first KV block (Dk=576). Prefill/PPL is unaffected because prefill takes the tile kernel, not the vec kernel. Fix: the score loop already covers Dk KV rows, so the V loop and the V pointer step Dk rows too. For asymmetric Dk>Dv the V row is only Dv wide, so threads with tid >= Dv have no V element (their VKQ lane is discarded at the output store anyway) and read 0 instead of stepping past the row. The change keys off the compile-time Dk != Dv, so every symmetric instantiation compiles to byte-identical code and modern GPUs (which never take this vec path for MLA) are unaffected. Validated on a Tesla P100 (sm_60) with DeepSeek-V2-Lite Q4_K_M: decode coherence restored for -mla 1/3 -fa 1, KLD vs the -fa 0 soft_max path drops from 4.79 to 1.4e-4 (same top token 27% -> 100%) at c1024, and TG is unchanged (82.8 t/s). Co-authored-by: mb8565 <244351746+mb8565@users.noreply.github.com> Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
parent
d5507e33ae
commit
4553cd0059
@ -195,11 +195,11 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
// printf("gridDims = %u, %u, %u, ncols = %d, head = %d, blockIdx.x = %d, blockIdx.y = %d, bounds = %d, %d, ne11 = %d, nb11 = %d, blockIdx.y*Dk = %d\n", gridDim.x, gridDim.y, gridDim.z, ncols, head, blockIdx.x, blockIdx.y, KV_min_max[sequence*gridDim.x + blockIdx.x].x, KV_min_max[sequence*gridDim.x + blockIdx.x].y, ne11, nb11, blockIdx.y*Dk);
|
||||
//}
|
||||
K += (first_y + blockIdx.y*Dk) * nb11;
|
||||
V += (first_y + blockIdx.y*Dv) * nb21;
|
||||
V += (first_y + blockIdx.y*Dk) * nb21;
|
||||
maskh += (first_y + blockIdx.y*Dk);
|
||||
for (int k_VKQ_0 = first_y + blockIdx.y*Dk; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*Dk,
|
||||
// Increment pointers after each loop:
|
||||
K += gridDim.y*Dk*nb11, V += gridDim.y*Dv*nb21, maskh += gridDim.y*Dk) {
|
||||
K += gridDim.y*Dk*nb11, V += gridDim.y*Dk*nb21, maskh += gridDim.y*Dk) {
|
||||
|
||||
// For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
|
||||
// see https://github.com/ggerganov/llama.cpp/pull/7061 .
|
||||
@ -278,15 +278,23 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Accumulate V over all Dk scored KV positions of this block (not Dv). For asymmetric MLA
|
||||
// head sizes (Dk=576 K, Dv=512 V) the K/score loop above covers Dk KV rows, so the V loop and
|
||||
// the V pointer stride must also step Dk KV rows or K and V desync (= garbage decode on sm_60,
|
||||
// which uses this vec kernel for MLA -fa 1 batch=1). Dk==Dv leaves every other case unchanged.
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < Dv; k0 += 2) {
|
||||
if (FATTN_KQ_STRIDE % Dv != 0 && k_VKQ_0 + k0 >= ne11) {
|
||||
for (int k0 = 0; k0 < Dk; k0 += 2) {
|
||||
if (FATTN_KQ_STRIDE % Dk != 0 && k_VKQ_0 + k0 >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
half2 V_k;
|
||||
reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid);
|
||||
reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid);
|
||||
// For asymmetric Dk>Dv the V row is only Dv wide, so threads tid>=Dv have no V element
|
||||
// (their VKQ lane is discarded at output anyway). Read 0 to avoid stepping past the row.
|
||||
half2 V_k = make_half2(0.0f, 0.0f);
|
||||
if (tid < Dv) {
|
||||
reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid);
|
||||
reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
VKQ[j] += V_k*KQ2[j*(Dk/2) + k0/2];
|
||||
|
||||
@ -189,11 +189,11 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
const int first_y = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].x : 0;
|
||||
|
||||
K += (first_y + blockIdx.y*Dk) * nb11;
|
||||
V += (first_y + blockIdx.y*Dv) * nb21;
|
||||
V += (first_y + blockIdx.y*Dk) * nb21;
|
||||
maskh += (first_y + blockIdx.y*Dk);
|
||||
for (int k_VKQ_0 = first_y + blockIdx.y*Dk; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*Dk,
|
||||
// Increment pointers after each loop:
|
||||
K += gridDim.y*Dk*nb11, V += gridDim.y*Dv*nb21, maskh += gridDim.y*Dk) {
|
||||
K += gridDim.y*Dk*nb11, V += gridDim.y*Dk*nb21, maskh += gridDim.y*Dk) {
|
||||
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
|
||||
@ -266,13 +266,19 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Accumulate V over all Dk scored KV positions of this block (not Dv). For asymmetric MLA
|
||||
// head sizes (Dk=576 K, Dv=512 V) the K/score loop above covers Dk KV rows, so the V loop and
|
||||
// the V pointer stride must also step Dk KV rows or K and V desync (= garbage decode on sm_60,
|
||||
// which uses this vec kernel for MLA -fa 1 batch=1). Dk==Dv leaves every other case unchanged.
|
||||
#pragma unroll
|
||||
for (int k = 0; k < Dv; ++k) {
|
||||
if (FATTN_KQ_STRIDE % Dv != 0 && k_VKQ_0 + k >= ne11) {
|
||||
for (int k = 0; k < Dk; ++k) {
|
||||
if (FATTN_KQ_STRIDE % Dk != 0 && k_VKQ_0 + k >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
const float V_ki = dequantize_1_v(V + k*nb21, tid);
|
||||
// For asymmetric Dk>Dv the V row is only Dv wide, so threads tid>=Dv have no V element
|
||||
// (their VKQ lane is discarded at output anyway). Read 0 to avoid stepping past the row.
|
||||
const float V_ki = tid < Dv ? dequantize_1_v(V + k*nb21, tid) : 0.0f;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
VKQ[j] += V_ki*KQ[j*Dk + k];
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user