diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 27e2ecb6..d96415b3 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -195,11 +195,11 @@ static __global__ void flash_attn_vec_ext_f16( // printf("gridDims = %u, %u, %u, ncols = %d, head = %d, blockIdx.x = %d, blockIdx.y = %d, bounds = %d, %d, ne11 = %d, nb11 = %d, blockIdx.y*Dk = %d\n", gridDim.x, gridDim.y, gridDim.z, ncols, head, blockIdx.x, blockIdx.y, KV_min_max[sequence*gridDim.x + blockIdx.x].x, KV_min_max[sequence*gridDim.x + blockIdx.x].y, ne11, nb11, blockIdx.y*Dk); //} K += (first_y + blockIdx.y*Dk) * nb11; - V += (first_y + blockIdx.y*Dv) * nb21; + V += (first_y + blockIdx.y*Dk) * nb21; maskh += (first_y + blockIdx.y*Dk); for (int k_VKQ_0 = first_y + blockIdx.y*Dk; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*Dk, // Increment pointers after each loop: - K += gridDim.y*Dk*nb11, V += gridDim.y*Dv*nb21, maskh += gridDim.y*Dk) { + K += gridDim.y*Dk*nb11, V += gridDim.y*Dk*nb21, maskh += gridDim.y*Dk) { // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression, // see https://github.com/ggerganov/llama.cpp/pull/7061 . @@ -278,15 +278,23 @@ static __global__ void flash_attn_vec_ext_f16( __syncthreads(); + // Accumulate V over all Dk scored KV positions of this block (not Dv). For asymmetric MLA + // head sizes (Dk=576 K, Dv=512 V) the K/score loop above covers Dk KV rows, so the V loop and + // the V pointer stride must also step Dk KV rows or K and V desync (= garbage decode on sm_60, + // which uses this vec kernel for MLA -fa 1 batch=1). Dk==Dv leaves every other case unchanged. #pragma unroll - for (int k0 = 0; k0 < Dv; k0 += 2) { - if (FATTN_KQ_STRIDE % Dv != 0 && k_VKQ_0 + k0 >= ne11) { + for (int k0 = 0; k0 < Dk; k0 += 2) { + if (FATTN_KQ_STRIDE % Dk != 0 && k_VKQ_0 + k0 >= ne11) { break; } - half2 V_k; - reinterpret_cast(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid); - reinterpret_cast(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid); + // For asymmetric Dk>Dv the V row is only Dv wide, so threads tid>=Dv have no V element + // (their VKQ lane is discarded at output anyway). Read 0 to avoid stepping past the row. + half2 V_k = make_half2(0.0f, 0.0f); + if (tid < Dv) { + reinterpret_cast(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid); + reinterpret_cast(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid); + } #pragma unroll for (int j = 0; j < ncols; ++j) { VKQ[j] += V_k*KQ2[j*(Dk/2) + k0/2]; diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index af8b4698..fd3b6fa4 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -189,11 +189,11 @@ static __global__ void flash_attn_vec_ext_f32( const int first_y = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].x : 0; K += (first_y + blockIdx.y*Dk) * nb11; - V += (first_y + blockIdx.y*Dv) * nb21; + V += (first_y + blockIdx.y*Dk) * nb21; maskh += (first_y + blockIdx.y*Dk); for (int k_VKQ_0 = first_y + blockIdx.y*Dk; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*Dk, // Increment pointers after each loop: - K += gridDim.y*Dk*nb11, V += gridDim.y*Dv*nb21, maskh += gridDim.y*Dk) { + K += gridDim.y*Dk*nb11, V += gridDim.y*Dk*nb21, maskh += gridDim.y*Dk) { // Calculate KQ tile and keep track of new maximum KQ values: @@ -266,13 +266,19 @@ static __global__ void flash_attn_vec_ext_f32( __syncthreads(); + // Accumulate V over all Dk scored KV positions of this block (not Dv). For asymmetric MLA + // head sizes (Dk=576 K, Dv=512 V) the K/score loop above covers Dk KV rows, so the V loop and + // the V pointer stride must also step Dk KV rows or K and V desync (= garbage decode on sm_60, + // which uses this vec kernel for MLA -fa 1 batch=1). Dk==Dv leaves every other case unchanged. #pragma unroll - for (int k = 0; k < Dv; ++k) { - if (FATTN_KQ_STRIDE % Dv != 0 && k_VKQ_0 + k >= ne11) { + for (int k = 0; k < Dk; ++k) { + if (FATTN_KQ_STRIDE % Dk != 0 && k_VKQ_0 + k >= ne11) { break; } - const float V_ki = dequantize_1_v(V + k*nb21, tid); + // For asymmetric Dk>Dv the V row is only Dv wide, so threads tid>=Dv have no V element + // (their VKQ lane is discarded at output anyway). Read 0 to avoid stepping past the row. + const float V_ki = tid < Dv ? dequantize_1_v(V + k*nb21, tid) : 0.0f; #pragma unroll for (int j = 0; j < ncols; ++j) { VKQ[j] += V_ki*KQ[j*Dk + k];