diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c9f906d793..2a30fb95c6 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -691,6 +691,7 @@ struct vk_device_struct { uint32_t coopmat_int_k; bool coopmat2; + bool coopmat2_bf16_support {}; bool coopmat2_decode_vector; bool pipeline_executable_properties_support {}; @@ -3139,7 +3140,7 @@ struct vk_fa_tuning_params { }; static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type); -static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc); +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type = GGML_TYPE_F16); static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { @@ -3279,6 +3280,13 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ FaCodePath path = device->coopmat2 ? FA_COOPMAT2 : device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; + if (path == FA_COOPMAT2 && k_type == GGML_TYPE_BF16 && !device->coopmat2_bf16_support) { + path = FA_COOPMAT1; + } + if (path == FA_COOPMAT1 && k_type == GGML_TYPE_BF16 && !device->coopmat_bf16_support) { + path = FA_SCALAR; + } + if (path == FA_COOPMAT1 && device->architecture == vk_device_architecture::NVIDIA_TURING) { // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090 path = FA_SCALAR; @@ -3288,7 +3296,7 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) || (!f32acc && device->coopmat_support_16x16x16_f16acc); const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); - bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc); + bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc, k_type); if (!shape_ok || !shmem_ok) { path = FA_SCALAR; @@ -3334,8 +3342,8 @@ static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const static std::vector get_fa_spec_constants(const vk_fa_pipeline_state& state) { const auto fa_block_bytes = [](ggml_type t) -> uint32_t { - // decodeBufF32 uses a block of vec4s for a better memory access pattern. - return t == GGML_TYPE_F32 ? 16u : (uint32_t) ggml_type_size(t); + if (t == GGML_TYPE_F32) return 16u; + return (uint32_t) ggml_type_size(t); }; return { /* 0 WorkGroupSize */ state.workgroup_size, @@ -3849,10 +3857,16 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t fa_sgs = fa.first.subgroup_size; const bool fa_ds = fa.first.subgroup_size == 0; + const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16; const bool use_mmq = ggml_vk_fa_scalar_uses_mmq(device, fa.first.k_type); const void * spv_data = nullptr; size_t spv_size = 0; - if (use_mmq) { + const char *name = nullptr; + if (bf16_kv) { + spv_data = flash_attn_f32_f16_fp32_data; + spv_size = flash_attn_f32_f16_fp32_len; + name = aligned ? "flash_attn_f32_bf16_aligned" : "flash_attn_f32_bf16"; + } else if (use_mmq) { #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->fp16) { if (f32acc) { spv_data = flash_attn_f32_f16_int8_data; spv_size = flash_attn_f32_f16_int8_len; } @@ -3862,6 +3876,7 @@ static void ggml_vk_load_shaders(vk_device& device) { spv_size = flash_attn_f32_f16_fp32_int8_len; } #endif + name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16"; } else { if (device->fp16) { if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; } @@ -3870,8 +3885,8 @@ static void ggml_vk_load_shaders(vk_device& device) { spv_data = flash_attn_f32_f16_fp32_data; spv_size = flash_attn_f32_f16_fp32_len; } + name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16"; } - const char *name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16"; ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, @@ -3889,11 +3904,25 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t fa_sgs = fa.first.subgroup_size; const bool fa_ds = fa.first.subgroup_size == 0; + const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16; + const void * spv_data; size_t spv_size; - if (f32acc) { spv_data = flash_attn_f32_f16_cm1_data; spv_size = flash_attn_f32_f16_cm1_len; } - else { spv_data = flash_attn_f32_f16_f16acc_cm1_data; spv_size = flash_attn_f32_f16_f16acc_cm1_len; } - const char *name = aligned ? "flash_attn_f32_f16_aligned_cm1" : "flash_attn_f32_f16_cm1"; + const char *name; + if (bf16_kv) { +#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (!device->coopmat_bf16_support) continue; + spv_data = flash_attn_f32_f16_bf16_cm1_data; + spv_size = flash_attn_f32_f16_bf16_cm1_len; + name = aligned ? "flash_attn_f32_bf16_aligned_cm1" : "flash_attn_f32_bf16_cm1"; +#else + continue; +#endif + } else { + if (f32acc) { spv_data = flash_attn_f32_f16_cm1_data; spv_size = flash_attn_f32_f16_cm1_len; } + else { spv_data = flash_attn_f32_f16_f16acc_cm1_data; spv_size = flash_attn_f32_f16_f16acc_cm1_len; } + name = aligned ? "flash_attn_f32_f16_aligned_cm1" : "flash_attn_f32_f16_cm1"; + } ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, @@ -3911,10 +3940,20 @@ static void ggml_vk_load_shaders(vk_device& device) { const bool aligned = fa.first.aligned; const bool f32acc = fa.first.f32acc; + const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16; const void * spv_data; size_t spv_size; const char * name; - if (aligned) { + if (bf16_kv) { +#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (!device->coopmat2_bf16_support) continue; + spv_data = flash_attn_f32_f16_bf16_cm2_data; + spv_size = flash_attn_f32_f16_bf16_cm2_len; + name = aligned ? "flash_attn_f32_bf16_aligned_cm2" : "flash_attn_f32_bf16_cm2"; +#else + continue; +#endif + } else if (aligned) { if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_aligned_f32acc_cm2"; } else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_aligned_f16acc_cm2"; } } else { @@ -5784,46 +5823,72 @@ static vk_device ggml_vk_get_device(size_t idx) { found_fp16_256 = false, found_fp32_128 = false, found_fp32_256 = false; + bool found_bf16_128 = false, + found_bf16_256 = false; // need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128 // with 32x16x16 and 256 with 32x32x16. for (auto &prop : flexible_dimensions) { if (prop.saturatingAccumulation == VK_FALSE && - prop.scope == VK_SCOPE_WORKGROUP_KHR && - prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && - prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + prop.scope == VK_SCOPE_WORKGROUP_KHR) { - if (prop.workgroupInvocations == 128 && - prop.MGranularity <= 32 && - prop.NGranularity <= 16 && - prop.KGranularity <= 16) { - if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && - prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { - found_fp16_128 = true; + if (prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + + if (prop.workgroupInvocations == 128 && + prop.MGranularity <= 32 && + prop.NGranularity <= 16 && + prop.KGranularity <= 16) { + if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + found_fp16_128 = true; + } + if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + found_fp32_128 = true; + } } - if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && - prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { - found_fp32_128 = true; + if (prop.workgroupInvocations == 256 && + prop.MGranularity <= 32 && + prop.NGranularity <= 32 && + prop.KGranularity <= 16) { + if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + found_fp16_256 = true; + } + if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + found_fp32_256 = true; + } } } - if (prop.workgroupInvocations == 256 && - prop.MGranularity <= 32 && - prop.NGranularity <= 32 && - prop.KGranularity <= 16) { - if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && - prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { - found_fp16_256 = true; + +#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR && + prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR && + prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + + if (prop.workgroupInvocations == 128 && + prop.MGranularity <= 32 && + prop.NGranularity <= 16 && + prop.KGranularity <= 16) { + found_bf16_128 = true; } - if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && - prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { - found_fp32_256 = true; + if (prop.workgroupInvocations == 256 && + prop.MGranularity <= 32 && + prop.NGranularity <= 32 && + prop.KGranularity <= 16) { + found_bf16_256 = true; } } +#endif } } if (found_fp16_128 && found_fp16_256 && found_fp32_128 && found_fp32_256 && coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) { device->coopmat2 = true; + device->coopmat2_bf16_support = found_bf16_128 && found_bf16_256; device->coopmat2_decode_vector = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector; } } @@ -9448,7 +9513,8 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con const uint32_t Br = params.block_rows; const uint32_t Bc = params.block_cols; - const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); + // BF16 uses the fp32 shader (FLOAT_TYPE=float) + const uint32_t float_type_size = (device->fp16 && k_type != GGML_TYPE_BF16) ? sizeof(ggml_fp16_t) : sizeof(float); const bool mmq = ggml_vk_fa_scalar_uses_mmq(device, k_type); @@ -9489,7 +9555,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con return supported; } -static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) { +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type) { // Needs to be kept up to date on shader changes const uint32_t Br = params.block_rows; const uint32_t Bc = params.block_cols; @@ -9519,8 +9585,10 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co const uint32_t vsh_stride = MatBc / 4 * row_split; const uint32_t ksh = ((kvshstride >= vsh_stride) ? (Bc * kvshstride) : (Bc * vsh_stride)) * f16vec4; + // BF16 PVMat accumulator is f32 (no bf16 accumulator support), so pvsh is vec4 (16 bytes) + const uint32_t pvsh_elem_size = (k_type == GGML_TYPE_BF16) ? 16u : f16vec4; const uint32_t osh_stride = params.row_split * MatBr / 4; - const uint32_t pvsh = MatBc * osh_stride * f16vec4; + const uint32_t pvsh = MatBc * osh_stride * pvsh_elem_size; const uint32_t slope = Br * acctype; @@ -9589,7 +9657,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx uint32_t workgroups_y = (uint32_t)neq2; uint32_t workgroups_z = (uint32_t)neq3; - const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32; + const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32 || k->type == GGML_TYPE_BF16; // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). @@ -16400,6 +16468,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm switch (t) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q8_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_0: @@ -16415,6 +16484,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (!fa_kv_ok(op->src[1]->type) || !fa_kv_ok(op->src[2]->type)) { return false; } + if ((op->src[1]->type == GGML_TYPE_BF16) != (op->src[2]->type == GGML_TYPE_BF16)) { + return false; + } if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) { // scalar/coopmat1 FA uses subgroupShuffle/subgroupAll return false; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 9a7957da97..66dcf61021 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -97,8 +97,17 @@ layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];}; #define FA_TYPE_Q5_0 6u #define FA_TYPE_Q5_1 7u #define FA_TYPE_Q8_0 8u +#define FA_TYPE_BF16 30u #define FA_TYPE_Q1_0 41u +#if defined(BFLOAT16) +#define O_TYPE float +#define O_TYPEV4 vec4 +#else +#define O_TYPE FLOAT_TYPE +#define O_TYPEV4 FLOAT_TYPEV4 +#endif + // Number of matrix elements per buffer block, derived from the K/V type spec // constant. F32 is treated as a vec4 "block" of 4 floats. F16 uses block size 1 // and bypasses the dequant path entirely. Quants follow their ggml block sizes. @@ -111,6 +120,7 @@ uint fa_block_elems(uint ty) { case FA_TYPE_Q5_0: return uint(QUANT_K_Q5_0); case FA_TYPE_Q5_1: return uint(QUANT_K_Q5_1); case FA_TYPE_Q8_0: return uint(QUANT_K_Q8_0); + case FA_TYPE_BF16: return 1u; case FA_TYPE_Q1_0: return uint(QUANT_K_Q1_0); // cm2-only, harmless elsewhere default: return 1u; } @@ -248,7 +258,7 @@ const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f; // Store the output when doing grouped query attention. // Rows index by Q's dimension 2, and the first N rows are valid. -void gqaStore(const in uint32_t r, const in uint32_t c, const in FLOAT_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +void gqaStore(const in uint32_t r, const in uint32_t c, const in O_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) { uint32_t offset = (iq2 + r) * HSV / 4 + c; data_ov4[o_offset + offset] = D_TYPEV4(elems); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index bffcc095be..23ae3833e5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -6,6 +6,10 @@ #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#if defined(BFLOAT16) +#extension GL_EXT_bfloat16 : enable +#endif + #extension GL_KHR_shader_subgroup_basic : enable #extension GL_KHR_shader_subgroup_arithmetic : enable #extension GL_KHR_shader_subgroup_vote : enable @@ -14,7 +18,9 @@ #include "types.glsl" #include "flash_attn_base.glsl" +#if !defined(BFLOAT16) #include "flash_attn_dequant.glsl" +#endif // These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd const uint32_t MatBr = 16; @@ -27,32 +33,32 @@ const uint32_t cols_per_thread = Bc / cols_per_iter; layout (binding = 0) readonly buffer Q {float data_q[];}; layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; -layout (binding = 1) readonly buffer K {float16_t data_k[];}; -layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; -layout (binding = 2) readonly buffer V {float16_t data_v[];}; -layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; +layout (binding = 1) readonly buffer K {FLOAT_TYPE data_k[];}; +layout (binding = 1) readonly buffer KV4 {FLOAT_TYPEV4 data_kv4[];}; +layout (binding = 2) readonly buffer V {FLOAT_TYPE data_v[];}; +layout (binding = 2) readonly buffer VV4 {FLOAT_TYPEV4 data_vv4[];}; layout (binding = 3) readonly buffer M {float16_t data_m[];}; shared float tmpsh[row_split]; -const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4 -shared f16vec4 Qf[Br * qstride]; +const uint32_t qstride = HSK_pad / 4 + 2; +shared FLOAT_TYPEV4 Qf[Br * qstride]; const uint psh_stride = Br / 4 + 2; -shared f16vec4 Psh[Bc * psh_stride]; +shared FLOAT_TYPEV4 Psh[Bc * psh_stride]; // Avoid padding for hsk==256 to make it fit in 48KB shmem. const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4; shared ACC_TYPEV4 sfsh[Bc * sfshstride]; const uint32_t D_pad = HSK_pad > HSV_pad ? HSK_pad : HSV_pad; -const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; // in units of f16vec4 +const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups const uint vsh_stride = v_cols; -shared f16vec4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)]; +shared FLOAT_TYPEV4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)]; const uint32_t osh_stride = row_split * MatBr / 4; -shared f16vec4 pvsh[MatBc * osh_stride]; +shared O_TYPEV4 pvsh[MatBc * osh_stride]; shared ACC_TYPE slope[Br]; @@ -76,7 +82,7 @@ void main() { if ((HSK % 16) != 0) { [[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) { if (i + tid < Br * qstride) { - Qf[i + tid] = f16vec4(0); + Qf[i + tid] = FLOAT_TYPEV4(0); } } barrier(); @@ -89,15 +95,15 @@ void main() { uint32_t r = (idx + tid) / (HSK / 4); if (r < Br && d < HSK / 4 && i * Br + r < N) { - Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); + Qf[r * qstride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); } } barrier(); - f16vec4 Of[rows_per_thread][d_per_thread]; + O_TYPEV4 Of[rows_per_thread][d_per_thread]; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) { - Of[r][d] = f16vec4(0.0); + Of[r][d] = O_TYPEV4(0.0); } } @@ -222,15 +228,18 @@ void main() { uint32_t d = (idx + tid) % (HSK_pad / 4); uint32_t c = (idx + tid) / (HSK_pad / 4); if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) { - f16vec4 K_Tf = f16vec4(0); + FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0); if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) { +#if !defined(BFLOAT16) if (USE_DECODE_K) { uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d; uint ib = coord / BLOCK_SIZE_K; uint iqs = (coord % BLOCK_SIZE_K); K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); - } else { - K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); + } else +#endif + { + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); } } @@ -244,16 +253,16 @@ void main() { // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16 // This is written transposed in order to allow for N being 8 if implementations need it coopmat SfMat = coopmat(0); - coopmat KMat; - coopmat QMat; + coopmat KMat; + coopmat QMat; [[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) { // If SHMEM_STAGING is set, a Bc * HSK_pad size tile of K is loaded to shmem - // If not, f16 K is loaded directly from global memory if aligned, otherwise + // If not, K is loaded directly from global memory if aligned, otherwise // staged through a Bc * MatBr size staging buffer. - // If K is not type f16, then it is always staged for dequantization. + // If K is a quant type, then it is always staged for dequantization. if (SHMEM_STAGING == 0) { - // For quants we always need to dequant into kvsh; for f16 we can load + // For quants we always need to dequant into kvsh; for f16/bf16 we can load // directly from global memory when alignment / bounds allow it. const bool stage_k = USE_DECODE_K || KV_bounds_check || d * 16 + 16 > HSK; if (stage_k) { @@ -262,15 +271,18 @@ void main() { uint32_t col_vec = (idx + tid) % (MatBr / 4); uint32_t row = (idx + tid) / (MatBr / 4); if (idx + tid < Bc * MatBr / 4) { - f16vec4 K_Tf = f16vec4(0); + FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0); if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) { +#if !defined(BFLOAT16) if (USE_DECODE_K) { uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE_K + d * 16 + col_vec * 4; uint ib = coord / BLOCK_SIZE_K; uint iqs = (coord % BLOCK_SIZE_K); K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); - } else { - K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]); + } else +#endif + { + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]); } } @@ -357,7 +369,7 @@ void main() { [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d_local] = float16_t(eMf[r]) * Of[r][d_local]; + Of[r][d_local] = O_TYPE(eMf[r]) * Of[r][d_local]; } } @@ -368,10 +380,10 @@ void main() { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; r += 4) { const uint row = tile_row(r); if (KV_bounds_check && j * Bc + col >= KV) { - Psh[col * psh_stride + row / 4] = f16vec4(0.0f); + Psh[col * psh_stride + row / 4] = FLOAT_TYPEV4(0.0f); } else { const vec4 mfvec = vec4(Mf[r], Mf[r + 1], Mf[r + 2], Mf[r + 3]); - const f16vec4 Pf = f16vec4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec)); + const FLOAT_TYPEV4 Pf = FLOAT_TYPEV4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec)); [[unroll]] for (uint32_t vec_idx = 0; vec_idx < 4; ++vec_idx) { Lf[r + vec_idx] += Pf[vec_idx]; } @@ -385,15 +397,18 @@ void main() { uint32_t d = (idx + tid) % (HSV_pad / 4); uint32_t c = (idx + tid) / (HSV_pad / 4); if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) { - f16vec4 V_Tf = f16vec4(0); + FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0); if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) { +#if !defined(BFLOAT16) if (USE_DECODE_V) { uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d; uint ib = coord / BLOCK_SIZE_V; uint iqs = (coord % BLOCK_SIZE_V); V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); - } else { - V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); + } else +#endif + { + V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); } } @@ -409,7 +424,7 @@ void main() { [[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) { const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16; - coopmat PVMat = coopmat(0); + coopmat PVMat = coopmat(0); // Preload V tiles for [Bc, 16 * num subgroups] const uint v_rows = Bc; @@ -417,11 +432,11 @@ void main() { const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x; // If SHMEM_STAGING is set, a Bc * HSV_pad size tile of V is loaded to shmem. - // If not, f16 V is loaded directly from global memory if aligned, otherwise + // If not, V is loaded directly from global memory if aligned, otherwise // staged through a Bc * MatBr size staging buffer. - // If V is not type f16, then it is always staged for dequantization. + // If V is a quant type, then it is always staged for dequantization. if (SHMEM_STAGING == 0) { - // For quants we always preload via kvsh. For f16 we only preload when + // For quants we always preload via kvsh. For f16/bf16 we only preload when // alignment / bounds force it (otherwise we coopMatLoad direct from data_vv4). const bool stage_v = USE_DECODE_V || KV_bounds_check; if (stage_v) { @@ -438,13 +453,16 @@ void main() { const uint iqs = coord % BLOCK_SIZE_V; if (!KV_bounds_check || (v_row < KV && v_col < HSV)) { +#if !defined(BFLOAT16) if (USE_DECODE_V) { kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); - } else { + } else +#endif + { kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4]; } } else { - kvsh[row * vsh_stride + col] = f16vec4(0.0f); + kvsh[row * vsh_stride + col] = FLOAT_TYPEV4(0.0f); } } } @@ -459,7 +477,7 @@ void main() { if (SHMEM_STAGING == 0) { if (!USE_DECODE_V && !KV_bounds_check) { - // F16 values can be loaded directly from global memory + // F16/BF16 values can be loaded directly from global memory const uint v_tile_row = j * Bc + bc_chunk * MatBc; const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4; coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor); @@ -573,7 +591,7 @@ void main() { [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { const uint d_local = d0 / threads_per_rowgroup; - Of[r][d_local] *= float16_t(ms); + Of[r][d_local] *= O_TYPE(ms); } } else { vs = exp(sink - Mf[r]); @@ -591,7 +609,7 @@ void main() { [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d_local] *= float16_t(Lfrcp[r]); + Of[r][d_local] *= O_TYPE(Lfrcp[r]); #if defined(FLOAT_TYPE_MAX) Of[r][d_local] = clamp(Of[r][d_local], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX); #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 6d45b4931d..b9c03fe499 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -8,6 +8,10 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#if defined(BFLOAT16) +#extension GL_EXT_bfloat16 : enable +#endif + #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_cooperative_matrix : enable #extension GL_NV_cooperative_matrix2 : enable @@ -21,7 +25,9 @@ #include "types.glsl" #include "flash_attn_base.glsl" +#if !defined(BFLOAT16) #include "dequant_funcs_cm2.glsl" +#endif // buffer_reference stride = sizeof(struct) = FaBlockBytesK/V. layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_K { @@ -31,6 +37,7 @@ layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_ uint8_t raw[FaBlockBytesV]; }; +#if !defined(BFLOAT16) float16_t faDecodeK(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { switch (FaTypeK) { case FA_TYPE_F32: return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock); @@ -91,6 +98,7 @@ f16vec4 faDecodeVVector(const decodeBufFA_V bl_in, const uint blockCoords[2], co #define FADECODEK , faDecodeK #define FADECODEV , faDecodeV #endif +#endif layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; layout (binding = 1) readonly buffer K {uint8_t data_k[];}; @@ -195,15 +203,15 @@ void main() { tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1); coopmat Q; - coopmat Qf16; + coopmat Qf16; uint32_t q_offset = gqa_iq1*p.nb01*4/*sizeof(float)*/ + iq2*p.nb02+iq3*p.nb03; coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad)); - Qf16 = coopmat(Q); - Qf16 *= float16_t(p.scale); + Q *= Q_TYPE(p.scale); + Qf16 = coopmat(Q); - coopmat O = coopmat(0); + coopmat O = coopmat(0); coopmat L, M; @@ -291,16 +299,20 @@ void main() { coopmat S = coopmat(0); - coopmat K_T; + coopmat K_T; uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; // F16: bs_k==1 (direct load). F32: bs_k==4 (vec4 / dequantFuncF32). Q4/Q8 family: bs_k==32. Q1_0: bs_k==128. +#if defined(BFLOAT16) + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose); +#else const bool k_use_decode = (bs_k > 1u); if (k_use_decode) { coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose FADECODEK); } else { coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose); } +#endif S = coopMatMulAdd(Qf16, K_T, S); if (LOGIT_SOFTCAP) { @@ -351,22 +363,26 @@ void main() { coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C); } - coopmat P_A = coopmat(P); + coopmat P_A = coopmat(P); // compute rowsum by multiplying by matrix of all ones. - coopmat One = coopmat(1.0); + coopmat One = coopmat(1.0); rowsum = coopmat(0.0); rowsum = coopMatMulAdd(P_A, One, rowsum); - coopmat V; + coopmat V; uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; +#if defined(BFLOAT16) + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad)); +#else const bool v_use_decode = (bs_v > 1u); if (v_use_decode) { coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) FADECODEV); } else { coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad)); } +#endif L = eM*L + rowsum; @@ -378,7 +394,7 @@ void main() { // resize eM by using smear/reduce coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); - O *= coopmat(eMdiag); + O *= coopmat(eMdiag); O = coopMatMulAdd(P_A, V, O); } @@ -427,7 +443,7 @@ void main() { if (sink > Mr[i]) { ms = exp(Mr[i] - sink); - O[i] *= float16_t(ms); + O[i] *= O_TYPE(ms); } else { vs = exp(sink - Mr[i]); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl index 02106f33cb..8704479d96 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl @@ -28,6 +28,9 @@ layout (binding = 2) readonly buffer V_PACKED_Q5_1 { block_q5_1_packed16 data[]; layout (binding = 1) readonly buffer K_PACKED_Q8_0 { block_q8_0_packed16 data[]; } k_packed_q8_0; layout (binding = 2) readonly buffer V_PACKED_Q8_0 { block_q8_0_packed16 data[]; } v_packed_q8_0; +layout (binding = 1) readonly buffer K_PACKED_BF16 { u16vec4 data[]; } k_packed_bf16; +layout (binding = 2) readonly buffer V_PACKED_BF16 { u16vec4 data[]; } v_packed_bf16; + // Q4_1 and Q5_1 packed32 views: aliased to the same memory as the packed16 // views, used by the MMQ K-side hot path for fast 4-uint loads. layout (binding = 1) readonly buffer K_PACKED_Q4_1_P32 { block_q4_1_packed32 data[]; } k_packed_q4_1_p32; @@ -99,6 +102,9 @@ layout (binding = 1) readonly buffer K_PACKED_Q5_1_P32 { block_q5_1_packed32 dat return FLOAT_TYPE(BUF.data[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); \ } +#define FA_DEQUANT4_BF16(BUF) \ + return FLOAT_TYPEV4(bf16_to_fp32(uvec4(BUF.data[(a_offset + ib) / 4]))); + FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { if (binding_idx == BINDING_IDX_K) { switch (FaTypeK) { @@ -108,6 +114,7 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(k_packed_q5_0) case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(k_packed_q5_1) case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(k_packed_q8_0) + case FA_TYPE_BF16: FA_DEQUANT4_BF16(k_packed_bf16) } } else { switch (FaTypeV) { @@ -117,6 +124,7 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(v_packed_q5_0) case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(v_packed_q5_1) case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(v_packed_q8_0) + case FA_TYPE_BF16: FA_DEQUANT4_BF16(v_packed_bf16) } } return FLOAT_TYPEV4(0); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index fa9b938e4f..de7dbec2c6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -662,6 +662,28 @@ void process_shaders() { } } + const std::map fa_bf16_dict = { + {"FLOAT_TYPE", "bfloat16_t"}, + {"FLOAT_TYPEV2", "bf16vec2"}, + {"FLOAT_TYPEV4", "bf16vec4"}, + {"ACC_TYPE", "float"}, + {"ACC_TYPEV2", "vec2"}, + {"ACC_TYPEV4", "vec4"}, + {"BFLOAT16", "1"}, + }; + +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + string_to_spv("flash_attn_f32_f16_bf16", "flash_attn_cm1.comp", + merge_maps(fa_bf16_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), + true, true, false, false); +#endif + +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + string_to_spv("flash_attn_f32_f16_bf16", "flash_attn_cm2.comp", + merge_maps(fa_bf16_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), + true, false, true, false); +#endif + std::map base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}}; for (const auto& tname : type_names) {