mb8565 4553cd0059
cuda : fix MLA flash-attn vec decode for asymmetric K/V head sizes (#2031)
The flash-attn vec kernels walk the KV cache in blocks of Dk rows for the
score loop but accumulate V in blocks of Dv. With Dk == Dv that is the same
thing, so normal attention shapes are fine. For absorbed MLA shapes where the
K and V head sizes differ (Dk=576/Dv=512 and Dk=192/Dv=128) the two loops step
a different number of KV rows, so K and V drift out of sync after the first
block and the V pointer reads the wrong cache rows.
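
The drift can be reproduced with a small host-side model. This is only a
sketch, not the kernel source: walk_kv, BlockStart, k_step and v_step are
hypothetical names, and the model assumes each outer iteration advances the
score loop by k_step rows and the V loop by v_step rows.

```cpp
#include <vector>

// Toy host-side model (hypothetical, not the kernel source): one entry per
// outer iteration, recording the first KV row that the score loop and the
// V loop would each touch.
struct BlockStart {
    int k_row; // first row scored in this block
    int v_row; // first row accumulated from V in this block
};

std::vector<BlockStart> walk_kv(int n_kv, int k_step, int v_step) {
    std::vector<BlockStart> starts;
    int k_row = 0;
    int v_row = 0;
    while (k_row < n_kv) {
        starts.push_back({k_row, v_row});
        k_row += k_step; // score loop advances by Dk rows
        v_row += v_step; // V loop advances by Dv rows in the broken case
    }
    return starts;
}
```

For walk_kv(2048, 576, 512) the block starts are (0, 0), (576, 512),
(1152, 1024) and (1728, 1536), so the gap grows by Dk - Dv = 64 rows per
block. With both steps set to 576 the two columns stay equal.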

This only shows up at decode (batch=1) on cards that fall back to the vec
kernel for MLA, which on NVIDIA means pre-Volta. There, deepseek2/GLM MLA models
run with -mla 1 -fa 1 or -mla 3 -fa 1 decode coherently for short prompts but
collapse into garbage once n_kv passes the first KV block (Dk=576 rows).
Prefill/PPL is unaffected because prefill takes the tile kernel, not the vec
kernel.

Fix: the score loop already covers Dk KV rows per block, so the V loop and the
V pointer now step Dk rows too. For asymmetric Dk>Dv the V row is only Dv wide,
so threads with tid >= Dv have no V element to load; they read 0 instead of
stepping past the end of the row (their VKQ lane is discarded at the output
store anyway).
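
A minimal host-side sketch of the fixed V accumulation, under assumptions: one
thread per K element (so tid ranges over Dk), a row-major V cache with Dv
floats per row, and hypothetical helper names load_v and accumulate_block.
It is not the kernel code, only a model of the indexing described above.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: value that thread `tid` loads from V row `row`.
// A V row holds only `dv` floats, so threads with tid >= dv read 0 rather
// than indexing past the end of the row into the next one.
float load_v(const std::vector<float> &v_cache, int row, int tid, int dv) {
    if (tid >= dv) {
        return 0.0f;
    }
    return v_cache[static_cast<std::size_t>(row) * dv + tid];
}

// Accumulates sum(scores[r] * V[r][tid]) over one block of `dk` rows starting
// at `row0`, so the V loop covers the same rows as the score loop.
float accumulate_block(const std::vector<float> &v_cache,
                       const std::vector<float> &scores,
                       int row0, int dk, int dv, int tid) {
    float acc = 0.0f;
    const int n_kv = static_cast<int>(scores.size());
    for (int r = row0; r < row0 + dk && r < n_kv; ++r) {
        acc += scores[r] * load_v(v_cache, r, tid, dv);
    }
    return acc;
}
```

In this model a thread with tid >= dv always accumulates 0, which matches the
note that its VKQ lane is discarded at the output store.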

The change keys off the compile-time Dk != Dv, so every symmetric instantiation
compiles to byte-identical code and modern GPUs (which never take this vec path
for MLA) are unaffected.
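
The compile-time keying can be illustrated with a C++17 `if constexpr` sketch.
The function name load_v_elem and its signature are hypothetical; the actual
kernel may express the same condition differently.

```cpp
// Hypothetical sketch of compile-time dispatch on Dk != Dv (requires C++17).
template <int Dk, int Dv>
float load_v_elem(const float *v_row, int tid) {
    if constexpr (Dk != Dv) {
        // Asymmetric shapes only: guard lanes that have no V element.
        return tid < Dv ? v_row[tid] : 0.0f;
    } else {
        // Symmetric shapes: the guarded branch above is discarded at compile
        // time, so this instantiation contains only the plain load.
        return v_row[tid];
    }
}
```

Because the condition depends only on template parameters, an instantiation
such as load_v_elem<128, 128> never contains the bounds check.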

Validated on a Tesla P100 (sm_60) with DeepSeek-V2-Lite Q4_K_M: decode coherence
is restored for -mla 1/3 -fa 1, KLD vs the -fa 0 soft_max path drops from 4.79
to 1.4e-4 (same-top-token rate 27% -> 100%) at c1024, and TG is unchanged
(82.8 t/s).

Co-authored-by: mb8565 <244351746+mb8565@users.noreply.github.com>
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-25 08:56:17 +02:00