vulkan: add Flash Attention support for BFloat16 KV cache (#23420)

* vulkan: add flash attention bf16 kv support

* vulkan: bf16 FA coopmat1 support

* vulkan: bf16 FA coopmat2 support

* fix FA bf16 f32 fallback

* fix FA bf16 coopmat1 shader

* fix FA bf16 coopmat2 shader

* code cleanup

* cleanup comment change

* address feedback

* add O_TYPE for cm2 FA

* use O_TYPE for gqaStore function

* reduce BFLOAT16 ifdefs
This commit is contained in:
Ruben Ortlam 2026-05-30 10:39:31 +02:00 committed by GitHub
parent 337528571d
commit 6e093b80ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 234 additions and 88 deletions

View File

@ -691,6 +691,7 @@ struct vk_device_struct {
uint32_t coopmat_int_k;
bool coopmat2;
bool coopmat2_bf16_support {};
bool coopmat2_decode_vector;
bool pipeline_executable_properties_support {};
@ -3139,7 +3140,7 @@ struct vk_fa_tuning_params {
};
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type);
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type = GGML_TYPE_F16);
static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
@ -3279,6 +3280,13 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
FaCodePath path = device->coopmat2 ? FA_COOPMAT2 :
device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
if (path == FA_COOPMAT2 && k_type == GGML_TYPE_BF16 && !device->coopmat2_bf16_support) {
path = FA_COOPMAT1;
}
if (path == FA_COOPMAT1 && k_type == GGML_TYPE_BF16 && !device->coopmat_bf16_support) {
path = FA_SCALAR;
}
if (path == FA_COOPMAT1 && device->architecture == vk_device_architecture::NVIDIA_TURING) {
// Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090
path = FA_SCALAR;
@ -3288,7 +3296,7 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) ||
(!f32acc && device->coopmat_support_16x16x16_f16acc);
const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc);
bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc, k_type);
if (!shape_ok || !shmem_ok) {
path = FA_SCALAR;
@ -3334,8 +3342,8 @@ static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const
static std::vector<uint32_t> get_fa_spec_constants(const vk_fa_pipeline_state& state) {
const auto fa_block_bytes = [](ggml_type t) -> uint32_t {
// decodeBufF32 uses a block of vec4s for a better memory access pattern.
return t == GGML_TYPE_F32 ? 16u : (uint32_t) ggml_type_size(t);
if (t == GGML_TYPE_F32) return 16u;
return (uint32_t) ggml_type_size(t);
};
return {
/* 0 WorkGroupSize */ state.workgroup_size,
@ -3849,10 +3857,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
const uint32_t fa_sgs = fa.first.subgroup_size;
const bool fa_ds = fa.first.subgroup_size == 0;
const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16;
const bool use_mmq = ggml_vk_fa_scalar_uses_mmq(device, fa.first.k_type);
const void * spv_data = nullptr;
size_t spv_size = 0;
if (use_mmq) {
const char *name = nullptr;
if (bf16_kv) {
spv_data = flash_attn_f32_f16_fp32_data;
spv_size = flash_attn_f32_f16_fp32_len;
name = aligned ? "flash_attn_f32_bf16_aligned" : "flash_attn_f32_bf16";
} else if (use_mmq) {
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->fp16) {
if (f32acc) { spv_data = flash_attn_f32_f16_int8_data; spv_size = flash_attn_f32_f16_int8_len; }
@ -3862,6 +3876,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
spv_size = flash_attn_f32_f16_fp32_int8_len;
}
#endif
name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16";
} else {
if (device->fp16) {
if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; }
@ -3870,8 +3885,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
spv_data = flash_attn_f32_f16_fp32_data;
spv_size = flash_attn_f32_f16_fp32_len;
}
name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16";
}
const char *name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16";
ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7,
sizeof(vk_flash_attn_push_constants), {Br, 1, 1},
get_fa_spec_constants(fa.first), aligned ? Bc : 1, true,
@ -3889,11 +3904,25 @@ static void ggml_vk_load_shaders(vk_device& device) {
const uint32_t fa_sgs = fa.first.subgroup_size;
const bool fa_ds = fa.first.subgroup_size == 0;
const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16;
const void * spv_data;
size_t spv_size;
if (f32acc) { spv_data = flash_attn_f32_f16_cm1_data; spv_size = flash_attn_f32_f16_cm1_len; }
else { spv_data = flash_attn_f32_f16_f16acc_cm1_data; spv_size = flash_attn_f32_f16_f16acc_cm1_len; }
const char *name = aligned ? "flash_attn_f32_f16_aligned_cm1" : "flash_attn_f32_f16_cm1";
const char *name;
if (bf16_kv) {
#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
if (!device->coopmat_bf16_support) continue;
spv_data = flash_attn_f32_f16_bf16_cm1_data;
spv_size = flash_attn_f32_f16_bf16_cm1_len;
name = aligned ? "flash_attn_f32_bf16_aligned_cm1" : "flash_attn_f32_bf16_cm1";
#else
continue;
#endif
} else {
if (f32acc) { spv_data = flash_attn_f32_f16_cm1_data; spv_size = flash_attn_f32_f16_cm1_len; }
else { spv_data = flash_attn_f32_f16_f16acc_cm1_data; spv_size = flash_attn_f32_f16_f16acc_cm1_len; }
name = aligned ? "flash_attn_f32_f16_aligned_cm1" : "flash_attn_f32_f16_cm1";
}
ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7,
sizeof(vk_flash_attn_push_constants), {Br, 1, 1},
get_fa_spec_constants(fa.first), aligned ? Bc : 1, true,
@ -3911,10 +3940,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
const bool aligned = fa.first.aligned;
const bool f32acc = fa.first.f32acc;
const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16;
const void * spv_data;
size_t spv_size;
const char * name;
if (aligned) {
if (bf16_kv) {
#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
if (!device->coopmat2_bf16_support) continue;
spv_data = flash_attn_f32_f16_bf16_cm2_data;
spv_size = flash_attn_f32_f16_bf16_cm2_len;
name = aligned ? "flash_attn_f32_bf16_aligned_cm2" : "flash_attn_f32_bf16_cm2";
#else
continue;
#endif
} else if (aligned) {
if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_aligned_f32acc_cm2"; }
else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_aligned_f16acc_cm2"; }
} else {
@ -5784,46 +5823,72 @@ static vk_device ggml_vk_get_device(size_t idx) {
found_fp16_256 = false,
found_fp32_128 = false,
found_fp32_256 = false;
bool found_bf16_128 = false,
found_bf16_256 = false;
// need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128
// with 32x16x16 and 256 with 32x32x16.
for (auto &prop : flexible_dimensions) {
if (prop.saturatingAccumulation == VK_FALSE &&
prop.scope == VK_SCOPE_WORKGROUP_KHR &&
prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
prop.scope == VK_SCOPE_WORKGROUP_KHR) {
if (prop.workgroupInvocations == 128 &&
prop.MGranularity <= 32 &&
prop.NGranularity <= 16 &&
prop.KGranularity <= 16) {
if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
found_fp16_128 = true;
if (prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
if (prop.workgroupInvocations == 128 &&
prop.MGranularity <= 32 &&
prop.NGranularity <= 16 &&
prop.KGranularity <= 16) {
if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
found_fp16_128 = true;
}
if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
found_fp32_128 = true;
}
}
if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
found_fp32_128 = true;
if (prop.workgroupInvocations == 256 &&
prop.MGranularity <= 32 &&
prop.NGranularity <= 32 &&
prop.KGranularity <= 16) {
if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
found_fp16_256 = true;
}
if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
found_fp32_256 = true;
}
}
}
if (prop.workgroupInvocations == 256 &&
prop.MGranularity <= 32 &&
prop.NGranularity <= 32 &&
prop.KGranularity <= 16) {
if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
found_fp16_256 = true;
#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
if (prop.workgroupInvocations == 128 &&
prop.MGranularity <= 32 &&
prop.NGranularity <= 16 &&
prop.KGranularity <= 16) {
found_bf16_128 = true;
}
if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
found_fp32_256 = true;
if (prop.workgroupInvocations == 256 &&
prop.MGranularity <= 32 &&
prop.NGranularity <= 32 &&
prop.KGranularity <= 16) {
found_bf16_256 = true;
}
}
#endif
}
}
if (found_fp16_128 && found_fp16_256 &&
found_fp32_128 && found_fp32_256 &&
coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) {
device->coopmat2 = true;
device->coopmat2_bf16_support = found_bf16_128 && found_bf16_256;
device->coopmat2_decode_vector = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector;
}
}
@ -9448,7 +9513,8 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
const uint32_t Br = params.block_rows;
const uint32_t Bc = params.block_cols;
const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
// BF16 uses the fp32 shader (FLOAT_TYPE=float)
const uint32_t float_type_size = (device->fp16 && k_type != GGML_TYPE_BF16) ? sizeof(ggml_fp16_t) : sizeof(float);
const bool mmq = ggml_vk_fa_scalar_uses_mmq(device, k_type);
@ -9489,7 +9555,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
return supported;
}
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type) {
// Needs to be kept up to date on shader changes
const uint32_t Br = params.block_rows;
const uint32_t Bc = params.block_cols;
@ -9519,8 +9585,10 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
const uint32_t vsh_stride = MatBc / 4 * row_split;
const uint32_t ksh = ((kvshstride >= vsh_stride) ? (Bc * kvshstride) : (Bc * vsh_stride)) * f16vec4;
// BF16 PVMat accumulator is f32 (no bf16 accumulator support), so pvsh is vec4 (16 bytes)
const uint32_t pvsh_elem_size = (k_type == GGML_TYPE_BF16) ? 16u : f16vec4;
const uint32_t osh_stride = params.row_split * MatBr / 4;
const uint32_t pvsh = MatBc * osh_stride * f16vec4;
const uint32_t pvsh = MatBc * osh_stride * pvsh_elem_size;
const uint32_t slope = Br * acctype;
@ -9589,7 +9657,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
uint32_t workgroups_y = (uint32_t)neq2;
uint32_t workgroups_z = (uint32_t)neq3;
const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32;
const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32 || k->type == GGML_TYPE_BF16;
// For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
// For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
@ -16400,6 +16468,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
switch (t) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q5_0:
@ -16415,6 +16484,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
if (!fa_kv_ok(op->src[1]->type) || !fa_kv_ok(op->src[2]->type)) {
return false;
}
if ((op->src[1]->type == GGML_TYPE_BF16) != (op->src[2]->type == GGML_TYPE_BF16)) {
return false;
}
if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) {
// scalar/coopmat1 FA uses subgroupShuffle/subgroupAll
return false;

View File

@ -97,8 +97,17 @@ layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];};
#define FA_TYPE_Q5_0 6u
#define FA_TYPE_Q5_1 7u
#define FA_TYPE_Q8_0 8u
#define FA_TYPE_BF16 30u
#define FA_TYPE_Q1_0 41u
#if defined(BFLOAT16)
#define O_TYPE float
#define O_TYPEV4 vec4
#else
#define O_TYPE FLOAT_TYPE
#define O_TYPEV4 FLOAT_TYPEV4
#endif
// Number of matrix elements per buffer block, derived from the K/V type spec
// constant. F32 is treated as a vec4 "block" of 4 floats. F16 uses block size 1
// and bypasses the dequant path entirely. Quants follow their ggml block sizes.
@ -111,6 +120,7 @@ uint fa_block_elems(uint ty) {
case FA_TYPE_Q5_0: return uint(QUANT_K_Q5_0);
case FA_TYPE_Q5_1: return uint(QUANT_K_Q5_1);
case FA_TYPE_Q8_0: return uint(QUANT_K_Q8_0);
case FA_TYPE_BF16: return 1u;
case FA_TYPE_Q1_0: return uint(QUANT_K_Q1_0); // cm2-only, harmless elsewhere
default: return 1u;
}
@ -248,7 +258,7 @@ const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f;
// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
void gqaStore(const in uint32_t r, const in uint32_t c, const in FLOAT_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
void gqaStore(const in uint32_t r, const in uint32_t c, const in O_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
uint32_t offset = (iq2 + r) * HSV / 4 + c;
data_ov4[o_offset + offset] = D_TYPEV4(elems);

View File

@ -6,6 +6,10 @@
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#if defined(BFLOAT16)
#extension GL_EXT_bfloat16 : enable
#endif
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_vote : enable
@ -14,7 +18,9 @@
#include "types.glsl"
#include "flash_attn_base.glsl"
#if !defined(BFLOAT16)
#include "flash_attn_dequant.glsl"
#endif
// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
const uint32_t MatBr = 16;
@ -27,32 +33,32 @@ const uint32_t cols_per_thread = Bc / cols_per_iter;
layout (binding = 0) readonly buffer Q {float data_q[];};
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
layout (binding = 1) readonly buffer K {float16_t data_k[];};
layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
layout (binding = 2) readonly buffer V {float16_t data_v[];};
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
layout (binding = 1) readonly buffer K {FLOAT_TYPE data_k[];};
layout (binding = 1) readonly buffer KV4 {FLOAT_TYPEV4 data_kv4[];};
layout (binding = 2) readonly buffer V {FLOAT_TYPE data_v[];};
layout (binding = 2) readonly buffer VV4 {FLOAT_TYPEV4 data_vv4[];};
layout (binding = 3) readonly buffer M {float16_t data_m[];};
shared float tmpsh[row_split];
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
shared f16vec4 Qf[Br * qstride];
const uint32_t qstride = HSK_pad / 4 + 2;
shared FLOAT_TYPEV4 Qf[Br * qstride];
const uint psh_stride = Br / 4 + 2;
shared f16vec4 Psh[Bc * psh_stride];
shared FLOAT_TYPEV4 Psh[Bc * psh_stride];
// Avoid padding for hsk==256 to make it fit in 48KB shmem.
const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4;
shared ACC_TYPEV4 sfsh[Bc * sfshstride];
const uint32_t D_pad = HSK_pad > HSV_pad ? HSK_pad : HSV_pad;
const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; // in units of f16vec4
const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2;
const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups
const uint vsh_stride = v_cols;
shared f16vec4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)];
shared FLOAT_TYPEV4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)];
const uint32_t osh_stride = row_split * MatBr / 4;
shared f16vec4 pvsh[MatBc * osh_stride];
shared O_TYPEV4 pvsh[MatBc * osh_stride];
shared ACC_TYPE slope[Br];
@ -76,7 +82,7 @@ void main() {
if ((HSK % 16) != 0) {
[[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) {
if (i + tid < Br * qstride) {
Qf[i + tid] = f16vec4(0);
Qf[i + tid] = FLOAT_TYPEV4(0);
}
}
barrier();
@ -89,15 +95,15 @@ void main() {
uint32_t r = (idx + tid) / (HSK / 4);
if (r < Br && d < HSK / 4 &&
i * Br + r < N) {
Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
Qf[r * qstride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
}
}
barrier();
f16vec4 Of[rows_per_thread][d_per_thread];
O_TYPEV4 Of[rows_per_thread][d_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
[[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) {
Of[r][d] = f16vec4(0.0);
Of[r][d] = O_TYPEV4(0.0);
}
}
@ -222,15 +228,18 @@ void main() {
uint32_t d = (idx + tid) % (HSK_pad / 4);
uint32_t c = (idx + tid) / (HSK_pad / 4);
if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) {
f16vec4 K_Tf = f16vec4(0);
FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0);
if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) {
#if !defined(BFLOAT16)
if (USE_DECODE_K) {
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d;
uint ib = coord / BLOCK_SIZE_K;
uint iqs = (coord % BLOCK_SIZE_K);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
} else {
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
} else
#endif
{
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
}
}
@ -244,16 +253,16 @@ void main() {
// Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
// This is written transposed in order to allow for N being 8 if implementations need it
coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
coopmat<FLOAT_TYPE, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
coopmat<FLOAT_TYPE, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
[[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
// If SHMEM_STAGING is set, a Bc * HSK_pad size tile of K is loaded to shmem
// If not, f16 K is loaded directly from global memory if aligned, otherwise
// If not, K is loaded directly from global memory if aligned, otherwise
// staged through a Bc * MatBr size staging buffer.
// If K is not type f16, then it is always staged for dequantization.
// If K is a quant type, then it is always staged for dequantization.
if (SHMEM_STAGING == 0) {
// For quants we always need to dequant into kvsh; for f16 we can load
// For quants we always need to dequant into kvsh; for f16/bf16 we can load
// directly from global memory when alignment / bounds allow it.
const bool stage_k = USE_DECODE_K || KV_bounds_check || d * 16 + 16 > HSK;
if (stage_k) {
@ -262,15 +271,18 @@ void main() {
uint32_t col_vec = (idx + tid) % (MatBr / 4);
uint32_t row = (idx + tid) / (MatBr / 4);
if (idx + tid < Bc * MatBr / 4) {
f16vec4 K_Tf = f16vec4(0);
FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0);
if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) {
#if !defined(BFLOAT16)
if (USE_DECODE_K) {
uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE_K + d * 16 + col_vec * 4;
uint ib = coord / BLOCK_SIZE_K;
uint iqs = (coord % BLOCK_SIZE_K);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
} else {
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
} else
#endif
{
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
}
}
@ -357,7 +369,7 @@ void main() {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d_local] = float16_t(eMf[r]) * Of[r][d_local];
Of[r][d_local] = O_TYPE(eMf[r]) * Of[r][d_local];
}
}
@ -368,10 +380,10 @@ void main() {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; r += 4) {
const uint row = tile_row(r);
if (KV_bounds_check && j * Bc + col >= KV) {
Psh[col * psh_stride + row / 4] = f16vec4(0.0f);
Psh[col * psh_stride + row / 4] = FLOAT_TYPEV4(0.0f);
} else {
const vec4 mfvec = vec4(Mf[r], Mf[r + 1], Mf[r + 2], Mf[r + 3]);
const f16vec4 Pf = f16vec4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec));
const FLOAT_TYPEV4 Pf = FLOAT_TYPEV4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec));
[[unroll]] for (uint32_t vec_idx = 0; vec_idx < 4; ++vec_idx) {
Lf[r + vec_idx] += Pf[vec_idx];
}
@ -385,15 +397,18 @@ void main() {
uint32_t d = (idx + tid) % (HSV_pad / 4);
uint32_t c = (idx + tid) / (HSV_pad / 4);
if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) {
f16vec4 V_Tf = f16vec4(0);
FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0);
if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) {
#if !defined(BFLOAT16)
if (USE_DECODE_V) {
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d;
uint ib = coord / BLOCK_SIZE_V;
uint iqs = (coord % BLOCK_SIZE_V);
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
} else {
V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
} else
#endif
{
V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
}
}
@ -409,7 +424,7 @@ void main() {
[[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) {
const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16;
coopmat<float16_t, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> PVMat = coopmat<float16_t, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
coopmat<O_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> PVMat = coopmat<O_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
// Preload V tiles for [Bc, 16 * num subgroups]
const uint v_rows = Bc;
@ -417,11 +432,11 @@ void main() {
const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x;
// If SHMEM_STAGING is set, a Bc * HSV_pad size tile of V is loaded to shmem.
// If not, f16 V is loaded directly from global memory if aligned, otherwise
// If not, V is loaded directly from global memory if aligned, otherwise
// staged through a Bc * MatBr size staging buffer.
// If V is not type f16, then it is always staged for dequantization.
// If V is a quant type, then it is always staged for dequantization.
if (SHMEM_STAGING == 0) {
// For quants we always preload via kvsh. For f16 we only preload when
// For quants we always preload via kvsh. For f16/bf16 we only preload when
// alignment / bounds force it (otherwise we coopMatLoad direct from data_vv4).
const bool stage_v = USE_DECODE_V || KV_bounds_check;
if (stage_v) {
@ -438,13 +453,16 @@ void main() {
const uint iqs = coord % BLOCK_SIZE_V;
if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
#if !defined(BFLOAT16)
if (USE_DECODE_V) {
kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
} else {
} else
#endif
{
kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
}
} else {
kvsh[row * vsh_stride + col] = f16vec4(0.0f);
kvsh[row * vsh_stride + col] = FLOAT_TYPEV4(0.0f);
}
}
}
@ -459,7 +477,7 @@ void main() {
if (SHMEM_STAGING == 0) {
if (!USE_DECODE_V && !KV_bounds_check) {
// F16 values can be loaded directly from global memory
// F16/BF16 values can be loaded directly from global memory
const uint v_tile_row = j * Bc + bc_chunk * MatBc;
const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
@ -573,7 +591,7 @@ void main() {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d_local = d0 / threads_per_rowgroup;
Of[r][d_local] *= float16_t(ms);
Of[r][d_local] *= O_TYPE(ms);
}
} else {
vs = exp(sink - Mf[r]);
@ -591,7 +609,7 @@ void main() {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d_local] *= float16_t(Lfrcp[r]);
Of[r][d_local] *= O_TYPE(Lfrcp[r]);
#if defined(FLOAT_TYPE_MAX)
Of[r][d_local] = clamp(Of[r][d_local], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX);
#endif

View File

@ -8,6 +8,10 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#if defined(BFLOAT16)
#extension GL_EXT_bfloat16 : enable
#endif
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
#extension GL_NV_cooperative_matrix2 : enable
@ -21,7 +25,9 @@
#include "types.glsl"
#include "flash_attn_base.glsl"
#if !defined(BFLOAT16)
#include "dequant_funcs_cm2.glsl"
#endif
// buffer_reference stride = sizeof(struct) = FaBlockBytesK/V.
layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_K {
@ -31,6 +37,7 @@ layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_
uint8_t raw[FaBlockBytesV];
};
#if !defined(BFLOAT16)
float16_t faDecodeK(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) {
switch (FaTypeK) {
case FA_TYPE_F32: return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock);
@ -91,6 +98,7 @@ f16vec4 faDecodeVVector(const decodeBufFA_V bl_in, const uint blockCoords[2], co
#define FADECODEK , faDecodeK
#define FADECODEV , faDecodeV
#endif
#endif
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
@ -195,15 +203,15 @@ void main() {
tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q;
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16;
uint32_t q_offset = gqa_iq1*p.nb01*4/*sizeof(float)*/ + iq2*p.nb02+iq3*p.nb03;
coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);
Qf16 *= float16_t(p.scale);
Q *= Q_TYPE(p.scale);
Qf16 = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
coopmat<O_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<O_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
@ -291,16 +299,20 @@ void main() {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
// F16: bs_k==1 (direct load). F32: bs_k==4 (vec4 / dequantFuncF32). Q4/Q8 family: bs_k==32. Q1_0: bs_k==128.
#if defined(BFLOAT16)
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose);
#else
const bool k_use_decode = (bs_k > 1u);
if (k_use_decode) {
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose FADECODEK);
} else {
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose);
}
#endif
S = coopMatMulAdd(Qf16, K_T, S);
if (LOGIT_SOFTCAP) {
@ -351,22 +363,26 @@ void main() {
coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C);
}
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P);
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P);
// compute rowsum by multiplying by matrix of all ones.
coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0);
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0);
rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
rowsum = coopMatMulAdd(P_A, One, rowsum);
coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V;
uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
#if defined(BFLOAT16)
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad));
#else
const bool v_use_decode = (bs_v > 1u);
if (v_use_decode) {
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) FADECODEV);
} else {
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad));
}
#endif
L = eM*L + rowsum;
@ -378,7 +394,7 @@ void main() {
// resize eM by using smear/reduce
coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
O *= coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(eMdiag);
O *= coopmat<O_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(eMdiag);
O = coopMatMulAdd(P_A, V, O);
}
@ -427,7 +443,7 @@ void main() {
if (sink > Mr[i]) {
ms = exp(Mr[i] - sink);
O[i] *= float16_t(ms);
O[i] *= O_TYPE(ms);
} else {
vs = exp(sink - Mr[i]);
}

View File

@ -28,6 +28,9 @@ layout (binding = 2) readonly buffer V_PACKED_Q5_1 { block_q5_1_packed16 data[];
layout (binding = 1) readonly buffer K_PACKED_Q8_0 { block_q8_0_packed16 data[]; } k_packed_q8_0;
layout (binding = 2) readonly buffer V_PACKED_Q8_0 { block_q8_0_packed16 data[]; } v_packed_q8_0;
layout (binding = 1) readonly buffer K_PACKED_BF16 { u16vec4 data[]; } k_packed_bf16;
layout (binding = 2) readonly buffer V_PACKED_BF16 { u16vec4 data[]; } v_packed_bf16;
// Q4_1 and Q5_1 packed32 views: aliased to the same memory as the packed16
// views, used by the MMQ K-side hot path for fast 4-uint loads.
layout (binding = 1) readonly buffer K_PACKED_Q4_1_P32 { block_q4_1_packed32 data[]; } k_packed_q4_1_p32;
@ -99,6 +102,9 @@ layout (binding = 1) readonly buffer K_PACKED_Q5_1_P32 { block_q5_1_packed32 dat
return FLOAT_TYPE(BUF.data[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); \
}
#define FA_DEQUANT4_BF16(BUF) \
return FLOAT_TYPEV4(bf16_to_fp32(uvec4(BUF.data[(a_offset + ib) / 4])));
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
switch (FaTypeK) {
@ -108,6 +114,7 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(k_packed_q5_0)
case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(k_packed_q5_1)
case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(k_packed_q8_0)
case FA_TYPE_BF16: FA_DEQUANT4_BF16(k_packed_bf16)
}
} else {
switch (FaTypeV) {
@ -117,6 +124,7 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(v_packed_q5_0)
case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(v_packed_q5_1)
case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(v_packed_q8_0)
case FA_TYPE_BF16: FA_DEQUANT4_BF16(v_packed_bf16)
}
}
return FLOAT_TYPEV4(0);

View File

@ -662,6 +662,28 @@ void process_shaders() {
}
}
const std::map<std::string, std::string> fa_bf16_dict = {
{"FLOAT_TYPE", "bfloat16_t"},
{"FLOAT_TYPEV2", "bf16vec2"},
{"FLOAT_TYPEV4", "bf16vec4"},
{"ACC_TYPE", "float"},
{"ACC_TYPEV2", "vec2"},
{"ACC_TYPEV4", "vec4"},
{"BFLOAT16", "1"},
};
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
string_to_spv("flash_attn_f32_f16_bf16", "flash_attn_cm1.comp",
merge_maps(fa_bf16_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}),
true, true, false, false);
#endif
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
string_to_spv("flash_attn_f32_f16_bf16", "flash_attn_cm2.comp",
merge_maps(fa_bf16_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}),
true, false, true, false);
#endif
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}};
for (const auto& tname : type_names) {