CPU FA: check if types are supported

This commit is contained in:
Kawrakow 2026-03-31 17:10:41 +03:00
parent 8b575c4b1f
commit 97b1a69998

View File

@ -17,6 +17,7 @@
#include <cstdint>
#include <cstring>
#include <cmath>
#include <unordered_set>
namespace {
inline uint32_t simple_gcd(uint32_t a, uint32_t b) {
@ -106,6 +107,36 @@ size_t iqk_fa_work_buffer_size(const struct ggml_tensor * dst, int nth) {
return size;
}
static inline const std::unordered_set<ggml_type> & supported_kv_types() {
#ifdef GGML_IQK_FA_ALL_QUANTS
static std::unordered_set<ggml_type> k_supported = {
GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q8_KV, GGML_TYPE_Q6_0, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_IQ4_NL
};
#else
static std::unordered_set<ggml_type> k_supported = {
GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q8_KV, GGML_TYPE_Q6_0,
};
#endif
return k_supported;
}
// Tells whether the (K, V) cache type pair passes the CPU flash-attention type check.
//
// - If type_k is BF16, the pair is accepted only when type_v is BF16 as well and the
//   build defines __AVX512BF16__; otherwise it is rejected.
// - For any other type_k, both types must be members of supported_kv_types().
//   (A BF16 type_v paired with a non-BF16 type_k is therefore rejected, since BF16
//   is not in that set.)
static inline bool are_kv_types_supported(ggml_type type_k, ggml_type type_v) {
    if (type_k != GGML_TYPE_BF16) {
        const auto & allowed = supported_kv_types();
        return allowed.count(type_k) > 0 && allowed.count(type_v) > 0;
    }
#ifdef __AVX512BF16__
    // BF16 K cache: V must be BF16 too.
    return type_v == GGML_TYPE_BF16;
#else
    // BF16 is never accepted without AVX512-BF16 support compiled in.
    return false;
#endif
}
// TODO: get the ggml_type enum here without pollution
//
extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
@ -136,6 +167,23 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
if (type_q != 0 || type_mask != 1 || max_bias > 0) return false;
// Reject K/V cache type combinations that are_kv_types_supported() does not accept:
// print a diagnostic listing the accepted types, then abort.
if (auto type_k = ggml_type(int_type_k_in), type_v = ggml_type(int_type_v); !are_kv_types_supported(type_k, type_v)) {
    // Only the caller with ith == 0 prints, so the message appears once
    // (presumably ith is the thread index -- confirm against the caller).
    if (ith == 0) {
        fprintf(stderr, "\n==================== KV cache types %s, %s are not supported on the CPU\n",
                ggml_type_name(type_k), ggml_type_name(type_v));
        auto & supported = supported_kv_types();
        fprintf(stderr, "Supported types are:\n");
        for (auto type : supported) {
            fprintf(stderr, "    %s\n", ggml_type_name(type));
        }
#ifdef __AVX512BF16__
        // BF16 is not in the set above; it is accepted only as a matching K/V pair.
        fprintf(stderr, "    %s, but only if K and V are both %s\n", ggml_type_name(GGML_TYPE_BF16), ggml_type_name(GGML_TYPE_BF16));
#endif
    }
    // NOTE(review): the barrier presumably lets the message be written before any
    // other thread reaches the abort -- confirm barrier() semantics.
    barrier(barrier_data);
    GGML_ABORT("Fatal error");
}
if (n_swa > 0) {
constexpr int kMinBatch = 256;
int ntokens = std::max(kMinBatch, neq1);