mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
vulkan: add Flash Attention support for BFloat16 KV cache (#23420)
* vulkan: add flash attention bf16 kv support * vulkan: bf16 FA coopmat1 support * vulkan: bf16 FA coopmat2 support * fix FA bf16 f32 fallback * fix FA bf16 coopmat1 shader * fix FA bf16 coopmat2 shader * code cleanup * cleanup comment change * address feedback * add O_TYPE for cm2 FA * use O_TYPE for gqaStore function * reduce BFLOAT16 ifdefs
This commit is contained in:
parent
337528571d
commit
6e093b80ea
@ -691,6 +691,7 @@ struct vk_device_struct {
|
||||
uint32_t coopmat_int_k;
|
||||
|
||||
bool coopmat2;
|
||||
bool coopmat2_bf16_support {};
|
||||
bool coopmat2_decode_vector;
|
||||
|
||||
bool pipeline_executable_properties_support {};
|
||||
@ -3139,7 +3140,7 @@ struct vk_fa_tuning_params {
|
||||
};
|
||||
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type);
|
||||
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
|
||||
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type = GGML_TYPE_F16);
|
||||
|
||||
static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
|
||||
|
||||
@ -3279,6 +3280,13 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
|
||||
FaCodePath path = device->coopmat2 ? FA_COOPMAT2 :
|
||||
device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
|
||||
|
||||
if (path == FA_COOPMAT2 && k_type == GGML_TYPE_BF16 && !device->coopmat2_bf16_support) {
|
||||
path = FA_COOPMAT1;
|
||||
}
|
||||
if (path == FA_COOPMAT1 && k_type == GGML_TYPE_BF16 && !device->coopmat_bf16_support) {
|
||||
path = FA_SCALAR;
|
||||
}
|
||||
|
||||
if (path == FA_COOPMAT1 && device->architecture == vk_device_architecture::NVIDIA_TURING) {
|
||||
// Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090
|
||||
path = FA_SCALAR;
|
||||
@ -3288,7 +3296,7 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
|
||||
bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) ||
|
||||
(!f32acc && device->coopmat_support_16x16x16_f16acc);
|
||||
const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
|
||||
bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc);
|
||||
bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc, k_type);
|
||||
|
||||
if (!shape_ok || !shmem_ok) {
|
||||
path = FA_SCALAR;
|
||||
@ -3334,8 +3342,8 @@ static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const
|
||||
|
||||
static std::vector<uint32_t> get_fa_spec_constants(const vk_fa_pipeline_state& state) {
|
||||
const auto fa_block_bytes = [](ggml_type t) -> uint32_t {
|
||||
// decodeBufF32 uses a block of vec4s for a better memory access pattern.
|
||||
return t == GGML_TYPE_F32 ? 16u : (uint32_t) ggml_type_size(t);
|
||||
if (t == GGML_TYPE_F32) return 16u;
|
||||
return (uint32_t) ggml_type_size(t);
|
||||
};
|
||||
return {
|
||||
/* 0 WorkGroupSize */ state.workgroup_size,
|
||||
@ -3849,10 +3857,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
const uint32_t fa_sgs = fa.first.subgroup_size;
|
||||
const bool fa_ds = fa.first.subgroup_size == 0;
|
||||
|
||||
const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16;
|
||||
const bool use_mmq = ggml_vk_fa_scalar_uses_mmq(device, fa.first.k_type);
|
||||
const void * spv_data = nullptr;
|
||||
size_t spv_size = 0;
|
||||
if (use_mmq) {
|
||||
const char *name = nullptr;
|
||||
if (bf16_kv) {
|
||||
spv_data = flash_attn_f32_f16_fp32_data;
|
||||
spv_size = flash_attn_f32_f16_fp32_len;
|
||||
name = aligned ? "flash_attn_f32_bf16_aligned" : "flash_attn_f32_bf16";
|
||||
} else if (use_mmq) {
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->fp16) {
|
||||
if (f32acc) { spv_data = flash_attn_f32_f16_int8_data; spv_size = flash_attn_f32_f16_int8_len; }
|
||||
@ -3862,6 +3876,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
spv_size = flash_attn_f32_f16_fp32_int8_len;
|
||||
}
|
||||
#endif
|
||||
name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16";
|
||||
} else {
|
||||
if (device->fp16) {
|
||||
if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; }
|
||||
@ -3870,8 +3885,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
spv_data = flash_attn_f32_f16_fp32_data;
|
||||
spv_size = flash_attn_f32_f16_fp32_len;
|
||||
}
|
||||
name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16";
|
||||
}
|
||||
const char *name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16";
|
||||
ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7,
|
||||
sizeof(vk_flash_attn_push_constants), {Br, 1, 1},
|
||||
get_fa_spec_constants(fa.first), aligned ? Bc : 1, true,
|
||||
@ -3889,11 +3904,25 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
const uint32_t fa_sgs = fa.first.subgroup_size;
|
||||
const bool fa_ds = fa.first.subgroup_size == 0;
|
||||
|
||||
const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16;
|
||||
|
||||
const void * spv_data;
|
||||
size_t spv_size;
|
||||
if (f32acc) { spv_data = flash_attn_f32_f16_cm1_data; spv_size = flash_attn_f32_f16_cm1_len; }
|
||||
else { spv_data = flash_attn_f32_f16_f16acc_cm1_data; spv_size = flash_attn_f32_f16_f16acc_cm1_len; }
|
||||
const char *name = aligned ? "flash_attn_f32_f16_aligned_cm1" : "flash_attn_f32_f16_cm1";
|
||||
const char *name;
|
||||
if (bf16_kv) {
|
||||
#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||
if (!device->coopmat_bf16_support) continue;
|
||||
spv_data = flash_attn_f32_f16_bf16_cm1_data;
|
||||
spv_size = flash_attn_f32_f16_bf16_cm1_len;
|
||||
name = aligned ? "flash_attn_f32_bf16_aligned_cm1" : "flash_attn_f32_bf16_cm1";
|
||||
#else
|
||||
continue;
|
||||
#endif
|
||||
} else {
|
||||
if (f32acc) { spv_data = flash_attn_f32_f16_cm1_data; spv_size = flash_attn_f32_f16_cm1_len; }
|
||||
else { spv_data = flash_attn_f32_f16_f16acc_cm1_data; spv_size = flash_attn_f32_f16_f16acc_cm1_len; }
|
||||
name = aligned ? "flash_attn_f32_f16_aligned_cm1" : "flash_attn_f32_f16_cm1";
|
||||
}
|
||||
ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7,
|
||||
sizeof(vk_flash_attn_push_constants), {Br, 1, 1},
|
||||
get_fa_spec_constants(fa.first), aligned ? Bc : 1, true,
|
||||
@ -3911,10 +3940,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
const bool aligned = fa.first.aligned;
|
||||
const bool f32acc = fa.first.f32acc;
|
||||
|
||||
const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16;
|
||||
const void * spv_data;
|
||||
size_t spv_size;
|
||||
const char * name;
|
||||
if (aligned) {
|
||||
if (bf16_kv) {
|
||||
#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||
if (!device->coopmat2_bf16_support) continue;
|
||||
spv_data = flash_attn_f32_f16_bf16_cm2_data;
|
||||
spv_size = flash_attn_f32_f16_bf16_cm2_len;
|
||||
name = aligned ? "flash_attn_f32_bf16_aligned_cm2" : "flash_attn_f32_bf16_cm2";
|
||||
#else
|
||||
continue;
|
||||
#endif
|
||||
} else if (aligned) {
|
||||
if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_aligned_f32acc_cm2"; }
|
||||
else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_aligned_f16acc_cm2"; }
|
||||
} else {
|
||||
@ -5784,46 +5823,72 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
found_fp16_256 = false,
|
||||
found_fp32_128 = false,
|
||||
found_fp32_256 = false;
|
||||
bool found_bf16_128 = false,
|
||||
found_bf16_256 = false;
|
||||
// need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128
|
||||
// with 32x16x16 and 256 with 32x32x16.
|
||||
for (auto &prop : flexible_dimensions) {
|
||||
if (prop.saturatingAccumulation == VK_FALSE &&
|
||||
prop.scope == VK_SCOPE_WORKGROUP_KHR &&
|
||||
prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
|
||||
prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
|
||||
prop.scope == VK_SCOPE_WORKGROUP_KHR) {
|
||||
|
||||
if (prop.workgroupInvocations == 128 &&
|
||||
prop.MGranularity <= 32 &&
|
||||
prop.NGranularity <= 16 &&
|
||||
prop.KGranularity <= 16) {
|
||||
if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
|
||||
prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
|
||||
found_fp16_128 = true;
|
||||
if (prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
|
||||
prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
|
||||
|
||||
if (prop.workgroupInvocations == 128 &&
|
||||
prop.MGranularity <= 32 &&
|
||||
prop.NGranularity <= 16 &&
|
||||
prop.KGranularity <= 16) {
|
||||
if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
|
||||
prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
|
||||
found_fp16_128 = true;
|
||||
}
|
||||
if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
|
||||
prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
|
||||
found_fp32_128 = true;
|
||||
}
|
||||
}
|
||||
if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
|
||||
prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
|
||||
found_fp32_128 = true;
|
||||
if (prop.workgroupInvocations == 256 &&
|
||||
prop.MGranularity <= 32 &&
|
||||
prop.NGranularity <= 32 &&
|
||||
prop.KGranularity <= 16) {
|
||||
if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
|
||||
prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
|
||||
found_fp16_256 = true;
|
||||
}
|
||||
if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
|
||||
prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
|
||||
found_fp32_256 = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (prop.workgroupInvocations == 256 &&
|
||||
prop.MGranularity <= 32 &&
|
||||
prop.NGranularity <= 32 &&
|
||||
prop.KGranularity <= 16) {
|
||||
if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
|
||||
prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
|
||||
found_fp16_256 = true;
|
||||
|
||||
#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||
if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
|
||||
prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
|
||||
prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
|
||||
prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
|
||||
|
||||
if (prop.workgroupInvocations == 128 &&
|
||||
prop.MGranularity <= 32 &&
|
||||
prop.NGranularity <= 16 &&
|
||||
prop.KGranularity <= 16) {
|
||||
found_bf16_128 = true;
|
||||
}
|
||||
if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
|
||||
prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
|
||||
found_fp32_256 = true;
|
||||
if (prop.workgroupInvocations == 256 &&
|
||||
prop.MGranularity <= 32 &&
|
||||
prop.NGranularity <= 32 &&
|
||||
prop.KGranularity <= 16) {
|
||||
found_bf16_256 = true;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
if (found_fp16_128 && found_fp16_256 &&
|
||||
found_fp32_128 && found_fp32_256 &&
|
||||
coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) {
|
||||
device->coopmat2 = true;
|
||||
device->coopmat2_bf16_support = found_bf16_128 && found_bf16_256;
|
||||
device->coopmat2_decode_vector = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector;
|
||||
}
|
||||
}
|
||||
@ -9448,7 +9513,8 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
|
||||
const uint32_t Br = params.block_rows;
|
||||
const uint32_t Bc = params.block_cols;
|
||||
|
||||
const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
|
||||
// BF16 uses the fp32 shader (FLOAT_TYPE=float)
|
||||
const uint32_t float_type_size = (device->fp16 && k_type != GGML_TYPE_BF16) ? sizeof(ggml_fp16_t) : sizeof(float);
|
||||
|
||||
const bool mmq = ggml_vk_fa_scalar_uses_mmq(device, k_type);
|
||||
|
||||
@ -9489,7 +9555,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
|
||||
return supported;
|
||||
}
|
||||
|
||||
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
|
||||
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type) {
|
||||
// Needs to be kept up to date on shader changes
|
||||
const uint32_t Br = params.block_rows;
|
||||
const uint32_t Bc = params.block_cols;
|
||||
@ -9519,8 +9585,10 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
|
||||
const uint32_t vsh_stride = MatBc / 4 * row_split;
|
||||
const uint32_t ksh = ((kvshstride >= vsh_stride) ? (Bc * kvshstride) : (Bc * vsh_stride)) * f16vec4;
|
||||
|
||||
// BF16 PVMat accumulator is f32 (no bf16 accumulator support), so pvsh is vec4 (16 bytes)
|
||||
const uint32_t pvsh_elem_size = (k_type == GGML_TYPE_BF16) ? 16u : f16vec4;
|
||||
const uint32_t osh_stride = params.row_split * MatBr / 4;
|
||||
const uint32_t pvsh = MatBc * osh_stride * f16vec4;
|
||||
const uint32_t pvsh = MatBc * osh_stride * pvsh_elem_size;
|
||||
|
||||
const uint32_t slope = Br * acctype;
|
||||
|
||||
@ -9589,7 +9657,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
uint32_t workgroups_y = (uint32_t)neq2;
|
||||
uint32_t workgroups_z = (uint32_t)neq3;
|
||||
|
||||
const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32;
|
||||
const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32 || k->type == GGML_TYPE_BF16;
|
||||
|
||||
// For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
|
||||
// For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
|
||||
@ -16400,6 +16468,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
switch (t) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_BF16:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
@ -16415,6 +16484,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
if (!fa_kv_ok(op->src[1]->type) || !fa_kv_ok(op->src[2]->type)) {
|
||||
return false;
|
||||
}
|
||||
if ((op->src[1]->type == GGML_TYPE_BF16) != (op->src[2]->type == GGML_TYPE_BF16)) {
|
||||
return false;
|
||||
}
|
||||
if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) {
|
||||
// scalar/coopmat1 FA uses subgroupShuffle/subgroupAll
|
||||
return false;
|
||||
|
||||
@ -97,8 +97,17 @@ layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];};
|
||||
#define FA_TYPE_Q5_0 6u
|
||||
#define FA_TYPE_Q5_1 7u
|
||||
#define FA_TYPE_Q8_0 8u
|
||||
#define FA_TYPE_BF16 30u
|
||||
#define FA_TYPE_Q1_0 41u
|
||||
|
||||
#if defined(BFLOAT16)
|
||||
#define O_TYPE float
|
||||
#define O_TYPEV4 vec4
|
||||
#else
|
||||
#define O_TYPE FLOAT_TYPE
|
||||
#define O_TYPEV4 FLOAT_TYPEV4
|
||||
#endif
|
||||
|
||||
// Number of matrix elements per buffer block, derived from the K/V type spec
|
||||
// constant. F32 is treated as a vec4 "block" of 4 floats. F16 uses block size 1
|
||||
// and bypasses the dequant path entirely. Quants follow their ggml block sizes.
|
||||
@ -111,6 +120,7 @@ uint fa_block_elems(uint ty) {
|
||||
case FA_TYPE_Q5_0: return uint(QUANT_K_Q5_0);
|
||||
case FA_TYPE_Q5_1: return uint(QUANT_K_Q5_1);
|
||||
case FA_TYPE_Q8_0: return uint(QUANT_K_Q8_0);
|
||||
case FA_TYPE_BF16: return 1u;
|
||||
case FA_TYPE_Q1_0: return uint(QUANT_K_Q1_0); // cm2-only, harmless elsewhere
|
||||
default: return 1u;
|
||||
}
|
||||
@ -248,7 +258,7 @@ const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f;
|
||||
|
||||
// Store the output when doing grouped query attention.
|
||||
// Rows index by Q's dimension 2, and the first N rows are valid.
|
||||
void gqaStore(const in uint32_t r, const in uint32_t c, const in FLOAT_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
void gqaStore(const in uint32_t r, const in uint32_t c, const in O_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
{
|
||||
uint32_t offset = (iq2 + r) * HSV / 4 + c;
|
||||
data_ov4[o_offset + offset] = D_TYPEV4(elems);
|
||||
|
||||
@ -6,6 +6,10 @@
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
|
||||
#if defined(BFLOAT16)
|
||||
#extension GL_EXT_bfloat16 : enable
|
||||
#endif
|
||||
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||
#extension GL_KHR_shader_subgroup_vote : enable
|
||||
@ -14,7 +18,9 @@
|
||||
|
||||
#include "types.glsl"
|
||||
#include "flash_attn_base.glsl"
|
||||
#if !defined(BFLOAT16)
|
||||
#include "flash_attn_dequant.glsl"
|
||||
#endif
|
||||
|
||||
// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
|
||||
const uint32_t MatBr = 16;
|
||||
@ -27,32 +33,32 @@ const uint32_t cols_per_thread = Bc / cols_per_iter;
|
||||
|
||||
layout (binding = 0) readonly buffer Q {float data_q[];};
|
||||
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
|
||||
layout (binding = 1) readonly buffer K {float16_t data_k[];};
|
||||
layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
|
||||
layout (binding = 2) readonly buffer V {float16_t data_v[];};
|
||||
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
|
||||
layout (binding = 1) readonly buffer K {FLOAT_TYPE data_k[];};
|
||||
layout (binding = 1) readonly buffer KV4 {FLOAT_TYPEV4 data_kv4[];};
|
||||
layout (binding = 2) readonly buffer V {FLOAT_TYPE data_v[];};
|
||||
layout (binding = 2) readonly buffer VV4 {FLOAT_TYPEV4 data_vv4[];};
|
||||
layout (binding = 3) readonly buffer M {float16_t data_m[];};
|
||||
|
||||
shared float tmpsh[row_split];
|
||||
|
||||
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
|
||||
shared f16vec4 Qf[Br * qstride];
|
||||
const uint32_t qstride = HSK_pad / 4 + 2;
|
||||
shared FLOAT_TYPEV4 Qf[Br * qstride];
|
||||
|
||||
const uint psh_stride = Br / 4 + 2;
|
||||
shared f16vec4 Psh[Bc * psh_stride];
|
||||
shared FLOAT_TYPEV4 Psh[Bc * psh_stride];
|
||||
|
||||
// Avoid padding for hsk==256 to make it fit in 48KB shmem.
|
||||
const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4;
|
||||
shared ACC_TYPEV4 sfsh[Bc * sfshstride];
|
||||
|
||||
const uint32_t D_pad = HSK_pad > HSV_pad ? HSK_pad : HSV_pad;
|
||||
const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; // in units of f16vec4
|
||||
const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2;
|
||||
const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups
|
||||
const uint vsh_stride = v_cols;
|
||||
shared f16vec4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)];
|
||||
shared FLOAT_TYPEV4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)];
|
||||
|
||||
const uint32_t osh_stride = row_split * MatBr / 4;
|
||||
shared f16vec4 pvsh[MatBc * osh_stride];
|
||||
shared O_TYPEV4 pvsh[MatBc * osh_stride];
|
||||
|
||||
shared ACC_TYPE slope[Br];
|
||||
|
||||
@ -76,7 +82,7 @@ void main() {
|
||||
if ((HSK % 16) != 0) {
|
||||
[[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) {
|
||||
if (i + tid < Br * qstride) {
|
||||
Qf[i + tid] = f16vec4(0);
|
||||
Qf[i + tid] = FLOAT_TYPEV4(0);
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
@ -89,15 +95,15 @@ void main() {
|
||||
uint32_t r = (idx + tid) / (HSK / 4);
|
||||
if (r < Br && d < HSK / 4 &&
|
||||
i * Br + r < N) {
|
||||
Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
|
||||
Qf[r * qstride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
|
||||
f16vec4 Of[rows_per_thread][d_per_thread];
|
||||
O_TYPEV4 Of[rows_per_thread][d_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
[[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) {
|
||||
Of[r][d] = f16vec4(0.0);
|
||||
Of[r][d] = O_TYPEV4(0.0);
|
||||
}
|
||||
}
|
||||
|
||||
@ -222,15 +228,18 @@ void main() {
|
||||
uint32_t d = (idx + tid) % (HSK_pad / 4);
|
||||
uint32_t c = (idx + tid) / (HSK_pad / 4);
|
||||
if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) {
|
||||
f16vec4 K_Tf = f16vec4(0);
|
||||
FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0);
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) {
|
||||
#if !defined(BFLOAT16)
|
||||
if (USE_DECODE_K) {
|
||||
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE_K;
|
||||
uint iqs = (coord % BLOCK_SIZE_K);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
} else {
|
||||
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -244,16 +253,16 @@ void main() {
|
||||
// Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
|
||||
// This is written transposed in order to allow for N being 8 if implementations need it
|
||||
coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
|
||||
coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
|
||||
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
|
||||
// If SHMEM_STAGING is set, a Bc * HSK_pad size tile of K is loaded to shmem
|
||||
// If not, f16 K is loaded directly from global memory if aligned, otherwise
|
||||
// If not, K is loaded directly from global memory if aligned, otherwise
|
||||
// staged through a Bc * MatBr size staging buffer.
|
||||
// If K is not type f16, then it is always staged for dequantization.
|
||||
// If K is a quant type, then it is always staged for dequantization.
|
||||
if (SHMEM_STAGING == 0) {
|
||||
// For quants we always need to dequant into kvsh; for f16 we can load
|
||||
// For quants we always need to dequant into kvsh; for f16/bf16 we can load
|
||||
// directly from global memory when alignment / bounds allow it.
|
||||
const bool stage_k = USE_DECODE_K || KV_bounds_check || d * 16 + 16 > HSK;
|
||||
if (stage_k) {
|
||||
@ -262,15 +271,18 @@ void main() {
|
||||
uint32_t col_vec = (idx + tid) % (MatBr / 4);
|
||||
uint32_t row = (idx + tid) / (MatBr / 4);
|
||||
if (idx + tid < Bc * MatBr / 4) {
|
||||
f16vec4 K_Tf = f16vec4(0);
|
||||
FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0);
|
||||
if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) {
|
||||
#if !defined(BFLOAT16)
|
||||
if (USE_DECODE_K) {
|
||||
uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE_K + d * 16 + col_vec * 4;
|
||||
uint ib = coord / BLOCK_SIZE_K;
|
||||
uint iqs = (coord % BLOCK_SIZE_K);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
} else {
|
||||
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -357,7 +369,7 @@ void main() {
|
||||
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
|
||||
const uint d_local = d0 / threads_per_rowgroup;
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d_local] = float16_t(eMf[r]) * Of[r][d_local];
|
||||
Of[r][d_local] = O_TYPE(eMf[r]) * Of[r][d_local];
|
||||
}
|
||||
}
|
||||
|
||||
@ -368,10 +380,10 @@ void main() {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; r += 4) {
|
||||
const uint row = tile_row(r);
|
||||
if (KV_bounds_check && j * Bc + col >= KV) {
|
||||
Psh[col * psh_stride + row / 4] = f16vec4(0.0f);
|
||||
Psh[col * psh_stride + row / 4] = FLOAT_TYPEV4(0.0f);
|
||||
} else {
|
||||
const vec4 mfvec = vec4(Mf[r], Mf[r + 1], Mf[r + 2], Mf[r + 3]);
|
||||
const f16vec4 Pf = f16vec4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec));
|
||||
const FLOAT_TYPEV4 Pf = FLOAT_TYPEV4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec));
|
||||
[[unroll]] for (uint32_t vec_idx = 0; vec_idx < 4; ++vec_idx) {
|
||||
Lf[r + vec_idx] += Pf[vec_idx];
|
||||
}
|
||||
@ -385,15 +397,18 @@ void main() {
|
||||
uint32_t d = (idx + tid) % (HSV_pad / 4);
|
||||
uint32_t c = (idx + tid) / (HSV_pad / 4);
|
||||
if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) {
|
||||
f16vec4 V_Tf = f16vec4(0);
|
||||
FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0);
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) {
|
||||
#if !defined(BFLOAT16)
|
||||
if (USE_DECODE_V) {
|
||||
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE_V;
|
||||
uint iqs = (coord % BLOCK_SIZE_V);
|
||||
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
} else {
|
||||
V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -409,7 +424,7 @@ void main() {
|
||||
[[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) {
|
||||
const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16;
|
||||
|
||||
coopmat<float16_t, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> PVMat = coopmat<float16_t, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
|
||||
coopmat<O_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> PVMat = coopmat<O_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
|
||||
|
||||
// Preload V tiles for [Bc, 16 * num subgroups]
|
||||
const uint v_rows = Bc;
|
||||
@ -417,11 +432,11 @@ void main() {
|
||||
const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x;
|
||||
|
||||
// If SHMEM_STAGING is set, a Bc * HSV_pad size tile of V is loaded to shmem.
|
||||
// If not, f16 V is loaded directly from global memory if aligned, otherwise
|
||||
// If not, V is loaded directly from global memory if aligned, otherwise
|
||||
// staged through a Bc * MatBr size staging buffer.
|
||||
// If V is not type f16, then it is always staged for dequantization.
|
||||
// If V is a quant type, then it is always staged for dequantization.
|
||||
if (SHMEM_STAGING == 0) {
|
||||
// For quants we always preload via kvsh. For f16 we only preload when
|
||||
// For quants we always preload via kvsh. For f16/bf16 we only preload when
|
||||
// alignment / bounds force it (otherwise we coopMatLoad direct from data_vv4).
|
||||
const bool stage_v = USE_DECODE_V || KV_bounds_check;
|
||||
if (stage_v) {
|
||||
@ -438,13 +453,16 @@ void main() {
|
||||
const uint iqs = coord % BLOCK_SIZE_V;
|
||||
|
||||
if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
|
||||
#if !defined(BFLOAT16)
|
||||
if (USE_DECODE_V) {
|
||||
kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
} else {
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
|
||||
}
|
||||
} else {
|
||||
kvsh[row * vsh_stride + col] = f16vec4(0.0f);
|
||||
kvsh[row * vsh_stride + col] = FLOAT_TYPEV4(0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -459,7 +477,7 @@ void main() {
|
||||
|
||||
if (SHMEM_STAGING == 0) {
|
||||
if (!USE_DECODE_V && !KV_bounds_check) {
|
||||
// F16 values can be loaded directly from global memory
|
||||
// F16/BF16 values can be loaded directly from global memory
|
||||
const uint v_tile_row = j * Bc + bc_chunk * MatBc;
|
||||
const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
|
||||
coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
|
||||
@ -573,7 +591,7 @@ void main() {
|
||||
|
||||
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
|
||||
const uint d_local = d0 / threads_per_rowgroup;
|
||||
Of[r][d_local] *= float16_t(ms);
|
||||
Of[r][d_local] *= O_TYPE(ms);
|
||||
}
|
||||
} else {
|
||||
vs = exp(sink - Mf[r]);
|
||||
@ -591,7 +609,7 @@ void main() {
|
||||
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
|
||||
const uint d_local = d0 / threads_per_rowgroup;
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d_local] *= float16_t(Lfrcp[r]);
|
||||
Of[r][d_local] *= O_TYPE(Lfrcp[r]);
|
||||
#if defined(FLOAT_TYPE_MAX)
|
||||
Of[r][d_local] = clamp(Of[r][d_local], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX);
|
||||
#endif
|
||||
|
||||
@ -8,6 +8,10 @@
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
|
||||
|
||||
#if defined(BFLOAT16)
|
||||
#extension GL_EXT_bfloat16 : enable
|
||||
#endif
|
||||
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
#extension GL_NV_cooperative_matrix2 : enable
|
||||
@ -21,7 +25,9 @@
|
||||
|
||||
#include "types.glsl"
|
||||
#include "flash_attn_base.glsl"
|
||||
#if !defined(BFLOAT16)
|
||||
#include "dequant_funcs_cm2.glsl"
|
||||
#endif
|
||||
|
||||
// buffer_reference stride = sizeof(struct) = FaBlockBytesK/V.
|
||||
layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_K {
|
||||
@ -31,6 +37,7 @@ layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_
|
||||
uint8_t raw[FaBlockBytesV];
|
||||
};
|
||||
|
||||
#if !defined(BFLOAT16)
|
||||
float16_t faDecodeK(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) {
|
||||
switch (FaTypeK) {
|
||||
case FA_TYPE_F32: return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock);
|
||||
@ -91,6 +98,7 @@ f16vec4 faDecodeVVector(const decodeBufFA_V bl_in, const uint blockCoords[2], co
|
||||
#define FADECODEK , faDecodeK
|
||||
#define FADECODEV , faDecodeV
|
||||
#endif
|
||||
#endif
|
||||
|
||||
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
|
||||
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
|
||||
@ -195,15 +203,15 @@ void main() {
|
||||
tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
|
||||
|
||||
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16;
|
||||
|
||||
uint32_t q_offset = gqa_iq1*p.nb01*4/*sizeof(float)*/ + iq2*p.nb02+iq3*p.nb03;
|
||||
coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));
|
||||
|
||||
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);
|
||||
Qf16 *= float16_t(p.scale);
|
||||
Q *= Q_TYPE(p.scale);
|
||||
Qf16 = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
|
||||
coopmat<O_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<O_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
|
||||
|
||||
@ -291,16 +299,20 @@ void main() {
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
|
||||
|
||||
uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
|
||||
// F16: bs_k==1 (direct load). F32: bs_k==4 (vec4 / dequantFuncF32). Q4/Q8 family: bs_k==32. Q1_0: bs_k==128.
|
||||
#if defined(BFLOAT16)
|
||||
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose);
|
||||
#else
|
||||
const bool k_use_decode = (bs_k > 1u);
|
||||
if (k_use_decode) {
|
||||
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose FADECODEK);
|
||||
} else {
|
||||
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose);
|
||||
}
|
||||
#endif
|
||||
S = coopMatMulAdd(Qf16, K_T, S);
|
||||
|
||||
if (LOGIT_SOFTCAP) {
|
||||
@ -351,22 +363,26 @@ void main() {
|
||||
coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C);
|
||||
}
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P);
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P);
|
||||
|
||||
// compute rowsum by multiplying by matrix of all ones.
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0);
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0);
|
||||
|
||||
rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
|
||||
rowsum = coopMatMulAdd(P_A, One, rowsum);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V;
|
||||
uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
|
||||
#if defined(BFLOAT16)
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad));
|
||||
#else
|
||||
const bool v_use_decode = (bs_v > 1u);
|
||||
if (v_use_decode) {
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) FADECODEV);
|
||||
} else {
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad));
|
||||
}
|
||||
#endif
|
||||
|
||||
L = eM*L + rowsum;
|
||||
|
||||
@ -378,7 +394,7 @@ void main() {
|
||||
// resize eM by using smear/reduce
|
||||
coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
|
||||
|
||||
O *= coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(eMdiag);
|
||||
O *= coopmat<O_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(eMdiag);
|
||||
O = coopMatMulAdd(P_A, V, O);
|
||||
}
|
||||
|
||||
@ -427,7 +443,7 @@ void main() {
|
||||
if (sink > Mr[i]) {
|
||||
ms = exp(Mr[i] - sink);
|
||||
|
||||
O[i] *= float16_t(ms);
|
||||
O[i] *= O_TYPE(ms);
|
||||
} else {
|
||||
vs = exp(sink - Mr[i]);
|
||||
}
|
||||
|
||||
@ -28,6 +28,9 @@ layout (binding = 2) readonly buffer V_PACKED_Q5_1 { block_q5_1_packed16 data[];
|
||||
layout (binding = 1) readonly buffer K_PACKED_Q8_0 { block_q8_0_packed16 data[]; } k_packed_q8_0;
|
||||
layout (binding = 2) readonly buffer V_PACKED_Q8_0 { block_q8_0_packed16 data[]; } v_packed_q8_0;
|
||||
|
||||
layout (binding = 1) readonly buffer K_PACKED_BF16 { u16vec4 data[]; } k_packed_bf16;
|
||||
layout (binding = 2) readonly buffer V_PACKED_BF16 { u16vec4 data[]; } v_packed_bf16;
|
||||
|
||||
// Q4_1 and Q5_1 packed32 views: aliased to the same memory as the packed16
|
||||
// views, used by the MMQ K-side hot path for fast 4-uint loads.
|
||||
layout (binding = 1) readonly buffer K_PACKED_Q4_1_P32 { block_q4_1_packed32 data[]; } k_packed_q4_1_p32;
|
||||
@ -99,6 +102,9 @@ layout (binding = 1) readonly buffer K_PACKED_Q5_1_P32 { block_q5_1_packed32 dat
|
||||
return FLOAT_TYPE(BUF.data[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); \
|
||||
}
|
||||
|
||||
#define FA_DEQUANT4_BF16(BUF) \
|
||||
return FLOAT_TYPEV4(bf16_to_fp32(uvec4(BUF.data[(a_offset + ib) / 4])));
|
||||
|
||||
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
if (binding_idx == BINDING_IDX_K) {
|
||||
switch (FaTypeK) {
|
||||
@ -108,6 +114,7 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(k_packed_q5_0)
|
||||
case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(k_packed_q5_1)
|
||||
case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(k_packed_q8_0)
|
||||
case FA_TYPE_BF16: FA_DEQUANT4_BF16(k_packed_bf16)
|
||||
}
|
||||
} else {
|
||||
switch (FaTypeV) {
|
||||
@ -117,6 +124,7 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(v_packed_q5_0)
|
||||
case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(v_packed_q5_1)
|
||||
case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(v_packed_q8_0)
|
||||
case FA_TYPE_BF16: FA_DEQUANT4_BF16(v_packed_bf16)
|
||||
}
|
||||
}
|
||||
return FLOAT_TYPEV4(0);
|
||||
|
||||
@ -662,6 +662,28 @@ void process_shaders() {
|
||||
}
|
||||
}
|
||||
|
||||
const std::map<std::string, std::string> fa_bf16_dict = {
|
||||
{"FLOAT_TYPE", "bfloat16_t"},
|
||||
{"FLOAT_TYPEV2", "bf16vec2"},
|
||||
{"FLOAT_TYPEV4", "bf16vec4"},
|
||||
{"ACC_TYPE", "float"},
|
||||
{"ACC_TYPEV2", "vec2"},
|
||||
{"ACC_TYPEV4", "vec4"},
|
||||
{"BFLOAT16", "1"},
|
||||
};
|
||||
|
||||
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
string_to_spv("flash_attn_f32_f16_bf16", "flash_attn_cm1.comp",
|
||||
merge_maps(fa_bf16_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}),
|
||||
true, true, false, false);
|
||||
#endif
|
||||
|
||||
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
string_to_spv("flash_attn_f32_f16_bf16", "flash_attn_cm2.comp",
|
||||
merge_maps(fa_bf16_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}),
|
||||
true, false, true, false);
|
||||
#endif
|
||||
|
||||
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}};
|
||||
|
||||
for (const auto& tname : type_names) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user