Reject unsupported K/V cache types in the CPU flash attention path.

Adds supported_kv_types() and are_kv_types_supported(), and makes
iqk_flash_attn_noalibi() list the supported types on stderr and abort when
the requested K/V cache type pair is not supported.

NOTE(review): the header names on the three context #include lines of the
first hunk are not recoverable from this copy of the patch; restore them
from the target file before applying.

diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp
index 0c9677b1..aa01237d 100644
--- a/ggml/src/iqk/iqk_flash_attn.cpp
+++ b/ggml/src/iqk/iqk_flash_attn.cpp
@@ -17,6 +17,7 @@
 #include
 #include
 #include
+#include <unordered_set>
 
 namespace {
 inline uint32_t simple_gcd(uint32_t a, uint32_t b) {
@@ -106,6 +107,38 @@ size_t iqk_fa_work_buffer_size(const struct ggml_tensor * dst, int nth) {
     return size;
 }
 
+// K/V cache types accepted by the CPU flash attention (more when GGML_IQK_FA_ALL_QUANTS is defined).
+static inline const std::unordered_set<ggml_type> & supported_kv_types() {
+#ifdef GGML_IQK_FA_ALL_QUANTS
+    static const std::unordered_set<ggml_type> k_supported = {
+        GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q8_KV, GGML_TYPE_Q6_0, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_IQ4_NL
+    };
+#else
+    static const std::unordered_set<ggml_type> k_supported = {
+        GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q8_KV, GGML_TYPE_Q6_0,
+    };
+#endif
+    return k_supported;
+}
+
+// BF16 requires K and V to both be BF16 and an AVX512BF16 build; otherwise both types must be in supported_kv_types().
+static inline bool are_kv_types_supported(ggml_type type_k, ggml_type type_v) {
+    if (type_k == GGML_TYPE_BF16) {
+        if (type_v != type_k) {
+            return false;
+        }
+#ifdef __AVX512BF16__
+        return true;
+#else
+        return false;
+#endif
+    }
+    auto & supported = supported_kv_types();
+    auto it_k = supported.find(type_k);
+    auto it_v = supported.find(type_v);
+    return it_k != supported.end() && it_v != supported.end();
+}
+
 // TODO: get the ggml_type enum here without polution
 //
 extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
@@ -136,6 +169,24 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
 
     if (type_q != 0 || type_mask != 1 || max_bias > 0) return false;
 
+    if (auto type_k = ggml_type(int_type_k_in), type_v = ggml_type(int_type_v); !are_kv_types_supported(type_k, type_v)) {
+        if (ith == 0) {
+            fprintf(stderr, "\n==================== KV cache types %s, %s are not supported on the CPU\n",
+                    ggml_type_name(type_k), ggml_type_name(type_v));
+            auto & supported = supported_kv_types();
+            fprintf(stderr, "Supported types are:\n");
+            for (auto type : supported) {
+                fprintf(stderr, " %s\n", ggml_type_name(type));
+            }
+#ifdef __AVX512BF16__
+            fprintf(stderr, " %s, but only if K and V are both %s\n", ggml_type_name(GGML_TYPE_BF16), ggml_type_name(GGML_TYPE_BF16));
+#endif
+        }
+        // NOTE(review): presumably lets thread 0 finish printing before any thread aborts - confirm barrier semantics.
+        barrier(barrier_data);
+        GGML_ABORT("Fatal error");
+    }
+
     if (n_swa > 0) {
         constexpr int kMinBatch = 256;
         int ntokens = std::max(kMinBatch, neq1);