diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 91fb07c93e..3192130ccf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -463,6 +463,7 @@ void main() { } rowmaxf = max(rowmaxf, float(Sf[r][c])); } + rowmaxf += FATTN_KQ_MAX_OFFSET; float Moldf = Mf[r]; // M = max(rowmax, Mold) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 23ae3833e5..16178e5770 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -352,6 +352,7 @@ void main() { } rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp])); } + rowmaxf += FATTN_KQ_MAX_OFFSET; float Moldf = Mf[r]; // Compute max across the row