mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
CPU FA: check if types are supported
This commit is contained in:
parent
8b575c4b1f
commit
97b1a69998
@ -17,6 +17,7 @@
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <cmath>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace {
|
||||
inline uint32_t simple_gcd(uint32_t a, uint32_t b) {
|
||||
@ -106,6 +107,36 @@ size_t iqk_fa_work_buffer_size(const struct ggml_tensor * dst, int nth) {
|
||||
return size;
|
||||
}
|
||||
|
||||
static inline const std::unordered_set<ggml_type> & supported_kv_types() {
|
||||
#ifdef GGML_IQK_FA_ALL_QUANTS
|
||||
static std::unordered_set<ggml_type> k_supported = {
|
||||
GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q8_KV, GGML_TYPE_Q6_0, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_IQ4_NL
|
||||
};
|
||||
#else
|
||||
static std::unordered_set<ggml_type> k_supported = {
|
||||
GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q8_KV, GGML_TYPE_Q6_0,
|
||||
};
|
||||
#endif
|
||||
return k_supported;
|
||||
}
|
||||
|
||||
static inline bool are_kv_types_supported(ggml_type type_k, ggml_type type_v) {
|
||||
if (type_k == GGML_TYPE_BF16) {
|
||||
if (type_v != type_k) {
|
||||
return false;
|
||||
}
|
||||
#ifdef __AVX512BF16__
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
auto & supported = supported_kv_types();
|
||||
auto it_k = supported.find(type_k);
|
||||
auto it_v = supported.find(type_v);
|
||||
return it_k != supported.end() && it_v != supported.end();
|
||||
}
|
||||
|
||||
// TODO: get the ggml_type enum here without polution
|
||||
//
|
||||
extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
|
||||
@ -136,6 +167,23 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
|
||||
if (type_q != 0 || type_mask != 1 || max_bias > 0) return false;
|
||||
|
||||
if (auto type_k = ggml_type(int_type_k_in), type_v = ggml_type(int_type_v); !are_kv_types_supported(type_k, type_v)) {
|
||||
if (ith == 0) {
|
||||
fprintf(stderr, "\n==================== KV cache types %s, %s are not supported on the CPU\n",
|
||||
ggml_type_name(type_k), ggml_type_name(type_v));
|
||||
auto & supported = supported_kv_types();
|
||||
fprintf(stderr, "Sopprted types are:\n");
|
||||
for (auto type : supported) {
|
||||
fprintf(stderr, " %s\n", ggml_type_name(type));
|
||||
}
|
||||
#ifdef __AVX512BF16__
|
||||
fprintf(stderr, " %s, but only if K and V are both %s\n", ggml_type_name(GGML_TYPE_BF16), ggml_type_name(GGML_TYPE_BF16));
|
||||
#endif
|
||||
}
|
||||
barrier(barrier_data);
|
||||
GGML_ABORT("Fatal error");
|
||||
}
|
||||
|
||||
if (n_swa > 0) {
|
||||
constexpr int kMinBatch = 256;
|
||||
int ntokens = std::max(kMinBatch, neq1);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user