From 4553cd0059acdda093500b7eb809d9b696890667 Mon Sep 17 00:00:00 2001 From: mb8565 Date: Thu, 25 Jun 2026 01:56:17 -0500 Subject: [PATCH] cuda : fix MLA flash-attn vec decode for asymmetric K/V head sizes (#2031) The flash-attn vec kernels walk the KV cache in blocks of Dk rows for the score loop but accumulate V in blocks of Dv. With Dk == Dv that is the same thing, so normal attention shapes are fine. For absorbed MLA shapes where the K and V head sizes differ (Dk=576/Dv=512 and Dk=192/Dv=128) the two loops step a different number of KV rows, so K and V drift out of sync after the first block and the V pointer reads the wrong cache rows. This only shows up at decode (batch=1) on cards that fall back to the vec kernel for MLA, which on NVIDIA is pre-Volta. There deepseek2/GLM MLA models with -mla 1 -fa 1 or -mla 3 -fa 1 decode coherently for short prompts but collapse into garbage once n_kv passes the first KV block (Dk=576). Prefill/PPL is unaffected because prefill takes the tile kernel, not the vec kernel. Fix: the score loop already covers Dk KV rows, so the V loop and the V pointer step Dk rows too. For asymmetric Dk>Dv the V row is only Dv wide, so threads with tid >= Dv have no V element (their VKQ lane is discarded at the output store anyway) and read 0 instead of stepping past the row. The change keys off the compile-time Dk != Dv, so every symmetric instantiation compiles to byte-identical code and modern GPUs (which never take this vec path for MLA) are unaffected. Validated on a Tesla P100 (sm_60) with DeepSeek-V2-Lite Q4_K_M: decode coherence restored for -mla 1/3 -fa 1, KLD vs the -fa 0 soft_max path drops from 4.79 to 1.4e-4 (same top token 27% -> 100%) at c1024, and TG is unchanged (82.8 t/s). Co-authored-by: mb8565 <244351746+mb8565@users.noreply.github.com> Co-authored-by: Claude Opus 4.8 --- ggml/src/ggml-cuda/fattn-vec-f16.cuh | 22 +++++++++++++++------- ggml/src/ggml-cuda/fattn-vec-f32.cuh | 16 +++++++++++----- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 27e2ecb6..d96415b3 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -195,11 +195,11 @@ static __global__ void flash_attn_vec_ext_f16( // printf("gridDims = %u, %u, %u, ncols = %d, head = %d, blockIdx.x = %d, blockIdx.y = %d, bounds = %d, %d, ne11 = %d, nb11 = %d, blockIdx.y*Dk = %d\n", gridDim.x, gridDim.y, gridDim.z, ncols, head, blockIdx.x, blockIdx.y, KV_min_max[sequence*gridDim.x + blockIdx.x].x, KV_min_max[sequence*gridDim.x + blockIdx.x].y, ne11, nb11, blockIdx.y*Dk); //} K += (first_y + blockIdx.y*Dk) * nb11; - V += (first_y + blockIdx.y*Dv) * nb21; + V += (first_y + blockIdx.y*Dk) * nb21; maskh += (first_y + blockIdx.y*Dk); for (int k_VKQ_0 = first_y + blockIdx.y*Dk; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*Dk, // Increment pointers after each loop: - K += gridDim.y*Dk*nb11, V += gridDim.y*Dv*nb21, maskh += gridDim.y*Dk) { + K += gridDim.y*Dk*nb11, V += gridDim.y*Dk*nb21, maskh += gridDim.y*Dk) { // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression, // see https://github.com/ggerganov/llama.cpp/pull/7061 . @@ -278,15 +278,23 @@ static __global__ void flash_attn_vec_ext_f16( __syncthreads(); + // Accumulate V over all Dk scored KV positions of this block (not Dv). For asymmetric MLA + // head sizes (Dk=576 K, Dv=512 V) the K/score loop above covers Dk KV rows, so the V loop and + // the V pointer stride must also step Dk KV rows or K and V desync (= garbage decode on sm_60, + // which uses this vec kernel for MLA -fa 1 batch=1). Dk==Dv leaves every other case unchanged. #pragma unroll - for (int k0 = 0; k0 < Dv; k0 += 2) { - if (FATTN_KQ_STRIDE % Dv != 0 && k_VKQ_0 + k0 >= ne11) { + for (int k0 = 0; k0 < Dk; k0 += 2) { + if (FATTN_KQ_STRIDE % Dk != 0 && k_VKQ_0 + k0 >= ne11) { break; } - half2 V_k; - reinterpret_cast(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid); - reinterpret_cast(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid); + // For asymmetric Dk>Dv the V row is only Dv wide, so threads tid>=Dv have no V element + // (their VKQ lane is discarded at output anyway). Read 0 to avoid stepping past the row. + half2 V_k = make_half2(0.0f, 0.0f); + if (tid < Dv) { + reinterpret_cast(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid); + reinterpret_cast(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid); + } #pragma unroll for (int j = 0; j < ncols; ++j) { VKQ[j] += V_k*KQ2[j*(Dk/2) + k0/2]; diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index af8b4698..fd3b6fa4 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -189,11 +189,11 @@ static __global__ void flash_attn_vec_ext_f32( const int first_y = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].x : 0; K += (first_y + blockIdx.y*Dk) * nb11; - V += (first_y + blockIdx.y*Dv) * nb21; + V += (first_y + blockIdx.y*Dk) * nb21; maskh += (first_y + blockIdx.y*Dk); for (int k_VKQ_0 = first_y + blockIdx.y*Dk; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*Dk, // Increment pointers after each loop: - K += gridDim.y*Dk*nb11, V += gridDim.y*Dv*nb21, maskh += gridDim.y*Dk) { + K += gridDim.y*Dk*nb11, V += gridDim.y*Dk*nb21, maskh += gridDim.y*Dk) { // Calculate KQ tile and keep track of new maximum KQ values: @@ -266,13 +266,19 @@ static __global__ void flash_attn_vec_ext_f32( __syncthreads(); + // Accumulate V over all Dk scored KV positions of this block (not Dv). For asymmetric MLA + // head sizes (Dk=576 K, Dv=512 V) the K/score loop above covers Dk KV rows, so the V loop and + // the V pointer stride must also step Dk KV rows or K and V desync (= garbage decode on sm_60, + // which uses this vec kernel for MLA -fa 1 batch=1). Dk==Dv leaves every other case unchanged. #pragma unroll - for (int k = 0; k < Dv; ++k) { - if (FATTN_KQ_STRIDE % Dv != 0 && k_VKQ_0 + k >= ne11) { + for (int k = 0; k < Dk; ++k) { + if (FATTN_KQ_STRIDE % Dk != 0 && k_VKQ_0 + k >= ne11) { break; } - const float V_ki = dequantize_1_v(V + k*nb21, tid); + // For asymmetric Dk>Dv the V row is only Dv wide, so threads tid>=Dv have no V element + // (their VKQ lane is discarded at output anyway). Read 0 to avoid stepping past the row. + const float V_ki = tid < Dv ? dequantize_1_v(V + k*nb21, tid) : 0.0f; #pragma unroll for (int j = 0; j < ncols; ++j) { VKQ[j] += V_ki*KQ[j*Dk + k];