Mixed KV cache (#1599)

This commit is contained in:
Kawrakow 2026-04-09 09:33:17 +02:00 committed by GitHub
parent 5950d0259e
commit 9db5d9907e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 148 additions and 7 deletions

View File

@ -1260,6 +1260,50 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.cache_type_v = argv[++i];
return true;
}
if (arg == "-ctk-first" || arg == "--cache-type-k-first") {
CHECK_ARG
auto p = string_split(argv[i], ",");
if (p.size() != 2) {
invalid_param = true;
} else {
params.type_k_first = p[0];
params.n_k_first = std::stoi(p[1].c_str());
}
return true;
}
if (arg == "-ctk-last" || arg == "--cache-type-k-last") {
CHECK_ARG
auto p = string_split(argv[i], ",");
if (p.size() != 2) {
invalid_param = true;
} else {
params.type_k_last = p[0];
params.n_k_last = std::stoi(p[1].c_str());
}
return true;
}
if (arg == "-ctv-first" || arg == "--cache-type-v-first") {
CHECK_ARG
auto p = string_split(argv[i], ",");
if (p.size() != 2) {
invalid_param = true;
} else {
params.type_v_first = p[0];
params.n_v_first = std::stoi(p[1].c_str());
}
return true;
}
if (arg == "-ctv-last" || arg == "--cache-type-v-last") {
CHECK_ARG
auto p = string_split(argv[i], ",");
if (p.size() != 2) {
invalid_param = true;
} else {
params.type_v_last = p[0];
params.n_v_last = std::stoi(p[1].c_str());
}
return true;
}
if (arg == "-ctkd" || arg == "--cache-type-k-draft") {
params.speculative.cache_type_k = argv[++i];
return true;
@ -2470,6 +2514,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", "-nkvo, --no-kv-offload", "disable KV offload" });
options.push_back({ "*", "-ctk, --cache-type-k TYPE", "KV cache data type for K (default: %s)", params.cache_type_k.c_str() });
options.push_back({ "*", "-ctv, --cache-type-v TYPE", "KV cache data type for V (default: %s)", params.cache_type_v.c_str() });
// Per-layer-range KV cache overrides. The short and long flag names must match
// the ones accepted by the argument parser (-ctk-* pairs with --cache-type-k-*,
// -ctv-* pairs with --cache-type-v-*).
options.push_back({ "*", "-ctk-first, --cache-type-k-first TYPE,N", "KV cache data type for the first N layers of K (default: %s,-1)", params.type_k_first.c_str() });
options.push_back({ "*", "-ctk-last, --cache-type-k-last TYPE,N", "KV cache data type for the last N layers of K (default: %s,-1)", params.type_k_last.c_str() });
options.push_back({ "*", "-ctv-first, --cache-type-v-first TYPE,N", "KV cache data type for the first N layers of V (default: %s,-1)", params.type_v_first.c_str() });
options.push_back({ "*", "-ctv-last, --cache-type-v-last TYPE,N", "KV cache data type for the last N layers of V (default: %s,-1)", params.type_v_last.c_str() });
options.push_back({ "*", "-ctkd, --cache-type-k-draft TYPE", "KV cache data type for K for the draft model" });
options.push_back({ "*", "-ctvd, --cache-type-v-draft TYPE", "KV cache data type for V for the draft model" });
@ -3357,6 +3405,14 @@ struct llama_model_params common_model_params_to_llama(const gpt_params & params
mparams.worst_graph_tokens = params.worst_graph_tokens;
mparams.type_k = kv_cache_type_from_str(params.cache_type_k);
mparams.type_v = kv_cache_type_from_str(params.cache_type_v);
// Per-layer-range KV cache overrides: convert the type names given on the
// command line to ggml types and copy the layer counts unchanged.
mparams.type_k_first = kv_cache_type_from_str(params.type_k_first);
mparams.type_k_last = kv_cache_type_from_str(params.type_k_last );
mparams.type_v_first = kv_cache_type_from_str(params.type_v_first);
mparams.type_v_last = kv_cache_type_from_str(params.type_v_last );
mparams.n_k_first = params.n_k_first;
mparams.n_k_last = params.n_k_last;
mparams.n_v_first = params.n_v_first;
mparams.n_v_last = params.n_v_last;
mparams.max_ctx_size = params.n_ctx;
mparams.n_seq_max = params.n_parallel;
mparams.n_ubatch = get_batch_ubatch(params).second;
@ -3464,6 +3520,20 @@ struct llama_context_params common_context_params_to_llama(const gpt_params & pa
if (!cparams.flash_attn && ggml_is_quantized(cparams.type_v)) {
throw std::runtime_error("Quantized V cache cannot be used without flash attention");
}
// Per-layer-range KV cache overrides: convert the type names to ggml types
// and copy the layer counts unchanged.
cparams.type_k_first = kv_cache_type_from_str(params.type_k_first);
cparams.type_k_last = kv_cache_type_from_str(params.type_k_last );
cparams.type_v_first = kv_cache_type_from_str(params.type_v_first);
cparams.type_v_last = kv_cache_type_from_str(params.type_v_last );
cparams.n_k_first = params.n_k_first;
cparams.n_k_last = params.n_k_last;
cparams.n_v_first = params.n_v_first;
cparams.n_v_last = params.n_v_last;
// A quantized V type for the first/last layers is rejected without flash
// attention, but only when the corresponding layer count is positive.
if (!cparams.flash_attn && ggml_is_quantized(cparams.type_v_first) && cparams.n_v_first > 0) {
throw std::runtime_error("Quantized V cache cannot be used without flash attention");
}
if (!cparams.flash_attn && ggml_is_quantized(cparams.type_v_last) && cparams.n_v_last > 0) {
throw std::runtime_error("Quantized V cache cannot be used without flash attention");
}
if (!params.offload_policy.empty()) cparams.offload_policy = (void *)&params.offload_policy;
if (!params.cuda_params.empty()) cparams.cuda_params = (void *)params.cuda_params.data();

View File

@ -285,7 +285,7 @@ struct gpt_params {
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
std::vector<std::string> ban_phrases; // strings that are banned in generation
int32_t banned_n = 1; // number of tokens that are banned in the phrase
size_t n_buffer = 0; // number of token buffers for string ban
size_t n_buffer = 0; // number of token buffers for string ban
bool can_ban_phrases = true; // whether to ban strings
std::vector<llama_model_kv_override> kv_overrides;
@ -373,6 +373,15 @@ struct gpt_params {
std::string reduce_type = "f16";
// Mixed KV cache: optional different cache types for the first/last N layers.
std::string type_k_first = "f16"; // K cache type name for the first n_k_first layers
std::string type_k_last = "f16"; // K cache type name for the last n_k_last layers
std::string type_v_first = "f16"; // V cache type name for the first n_v_first layers
std::string type_v_last = "f16"; // V cache type name for the last n_v_last layers
int32_t n_k_first = -1; // number of leading layers using type_k_first
int32_t n_k_last = -1; // number of trailing layers using type_k_last
int32_t n_v_first = -1; // number of leading layers using type_v_first
int32_t n_v_last = -1; // number of trailing layers using type_v_last
// multimodal models (see examples/mtmd)
common_params_model mmproj;
bool mmproj_use_gpu = true; // use GPU for multimodal model

View File

@ -385,6 +385,14 @@ extern "C" {
int32_t fit_margin;
bool fit;
int32_t worst_graph_tokens;
// Per-layer-range KV cache types and layer counts.
// NOTE(review): presumably mirrors the same-named llama_context_params fields; the consumer is not visible here - confirm.
enum ggml_type type_k_first; // K cache type for the first n_k_first layers
enum ggml_type type_k_last; // K cache type for the last n_k_last layers
enum ggml_type type_v_first; // V cache type for the first n_v_first layers
enum ggml_type type_v_last; // V cache type for the last n_v_last layers
int32_t n_k_first;
int32_t n_k_last;
int32_t n_v_first;
int32_t n_v_last;
// proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
const float * tensor_split;
@ -453,6 +461,14 @@ extern "C" {
enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
enum ggml_type type_reduce; // data type for reduce operations
// Per-layer-range KV cache overrides, forwarded to the KV cache initialization.
enum ggml_type type_k_first; // K cache type for the first n_k_first layers
enum ggml_type type_k_last; // K cache type for the last n_k_last layers
enum ggml_type type_v_first; // V cache type for the first n_v_first layers
enum ggml_type type_v_last; // V cache type for the last n_v_last layers
int32_t n_k_first; // number of leading layers using type_k_first
int32_t n_k_last; // number of trailing layers using type_k_last
int32_t n_v_first; // number of leading layers using type_v_first
int32_t n_v_last; // number of trailing layers using type_v_last
// Keep the booleans together to avoid misalignment during copy-by-value.
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)

View File

@ -720,7 +720,15 @@ static bool llama_kv_cache_init(
ggml_type type_k,
ggml_type type_v,
uint32_t kv_size,
bool offload) {
bool offload,
ggml_type type_k_first,
ggml_type type_k_last,
ggml_type type_v_first,
ggml_type type_v_last,
int32_t n_k_first,
int32_t n_k_last,
int32_t n_v_first,
int32_t n_v_last) {
const llama_model & model = ctx->model;
const llama_cparams & cparams = ctx->cparams;
@ -926,10 +934,30 @@ static bool llama_kv_cache_init(
split_cache_i = false;
}
int n_embd_head_v = hparams.n_embd_head_v(i);
k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size);
// Choose the K-cache type for layer i, starting from the global type_k.
auto this_type_k = type_k;
// The first n_k_first layers take type_k_first (only when the count is positive and the type differs).
if (type_k_first != type_k && n_k_first > 0 && i < n_k_first) {
this_type_k = type_k_first;
}
// The last n_k_last layers take type_k_last; this is checked second, so it wins if the two ranges overlap.
// NOTE(review): if n_layer is unsigned and n_k_last exceeds it, the subtraction would wrap - confirm n_layer's type.
if (type_k_last != type_k && n_k_last > 0 && i >= n_layer - n_k_last) {
this_type_k = type_k_last;
}
// Log only the layers that deviate from the global K type.
if (this_type_k != type_k) {
LLAMA_LOG_INFO("================= Setting K-cache type in layer %2d to %s\n", i, ggml_type_name(this_type_k));
}
k = ggml_new_tensor_2d(ctx, this_type_k, n_embd_head_k, n_head_kv*kv_size);
int64_t v_ne = int64_t(n_embd_v_row)*kv_size;
v = ggml_new_tensor_1d(ctx, type_v, v_ne);
// Choose the V-cache type for layer i, starting from the global type_v.
auto this_type_v = type_v;
// The first n_v_first layers take type_v_first (only when the count is positive and the type differs).
if (type_v_first != type_v && n_v_first > 0 && i < n_v_first) {
this_type_v = type_v_first;
}
// The last n_v_last layers take type_v_last; this is checked second, so it wins if the two ranges overlap.
// NOTE(review): if n_layer is unsigned and n_v_last exceeds it, the subtraction would wrap - confirm n_layer's type.
if (type_v_last != type_v && n_v_last > 0 && i >= n_layer - n_v_last) {
this_type_v = type_v_last;
}
// Log only the layers that deviate from the global V type.
if (this_type_v != type_v) {
LLAMA_LOG_INFO("================= Setting V-cache type in layer %2d to %s\n", i, ggml_type_name(this_type_v));
}
v = ggml_new_tensor_1d(ctx, this_type_v, v_ne);
auto k_name = std::string{"cache_k_l"} + std::to_string(i);
auto v_name = std::string{"cache_v_l"} + std::to_string(i);
@ -952,7 +980,7 @@ static bool llama_kv_cache_init(
LLAMA_LOG_DEBUG("K_cache(%d, %d): using %d instead of %ld heads\n",
i, is, nhead_kv, extra_K->splits[is]->ne[1]/n_embd_head_k);
}
split_k_l.tensor_splits[is] = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, nhead_kv * kv_size);
split_k_l.tensor_splits[is] = ggml_new_tensor_2d(ctx, this_type_k, n_embd_head_k, nhead_kv * kv_size);
auto split_name = k_name + '.' + std::to_string(is);
ggml_set_name(split_k_l.tensor_splits[is], split_name.c_str());
mem_split[is] += ggml_nbytes(split_k_l.tensor_splits[is]);
@ -963,7 +991,7 @@ static bool llama_kv_cache_init(
for (int is = 0; is < extra_V->n_device; ++is) {
auto split = extra_V->splits[is];
if (!split) continue;
split_v_l.tensor_splits[is] = ggml_new_tensor_1d(ctx, type_v, split->ne[1] * kv_size);
split_v_l.tensor_splits[is] = ggml_new_tensor_1d(ctx, this_type_v, split->ne[1] * kv_size);
auto split_name = v_name + '.' + std::to_string(is);
ggml_set_name(split_v_l.tensor_splits[is], split_name.c_str());
mem_split[is] += ggml_nbytes(split_v_l.tensor_splits[is]);
@ -4907,6 +4935,14 @@ struct llama_model_params llama_model_default_params() {
/*.fit_margin =*/ 0,
/*.fit =*/ false,
/*.worst_graph_tokens =*/ 0,
/*.type_k_first =*/ GGML_TYPE_F16,
/*.type_k_last =*/ GGML_TYPE_F16,
/*.type_v_first =*/ GGML_TYPE_F16,
/*.type_v_last =*/ GGML_TYPE_F16,
/*.n_k_first =*/ -1,
/*.n_k_last =*/ -1,
/*.n_v_first =*/ -1,
/*.n_v_last =*/ -1,
/*.tensor_split =*/ nullptr,
/*.rpc_servers =*/ nullptr,
/*.progress_callback =*/ nullptr,
@ -4962,6 +4998,14 @@ struct llama_context_params llama_context_default_params() {
/*.type_k =*/ GGML_TYPE_F16,
/*.type_v =*/ GGML_TYPE_F16,
/*.type_reduce =*/ GGML_TYPE_F16,
/*.type_k_first =*/ GGML_TYPE_F16,
/*.type_k_last =*/ GGML_TYPE_F16,
/*.type_v_first =*/ GGML_TYPE_F16,
/*.type_v_last =*/ GGML_TYPE_F16,
/*.n_k_first =*/ -1,
/*.n_k_last =*/ -1,
/*.n_v_first =*/ -1,
/*.n_v_last =*/ -1,
/*.logits_all =*/ false,
/*.embeddings =*/ false,
/*.offload_kqv =*/ true,
@ -5670,7 +5714,9 @@ struct llama_context * llama_init_from_model(
}
ctx->backends.push_back(ctx->backend_cpu);
if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
// Forward the per-layer-range cache overrides. Argument order is
// (type_k_first, type_k_last, type_v_first, type_v_last, n_k_first, n_k_last, n_v_first, n_v_last);
// the fourth type must be type_v_last, otherwise the last V layers get the "first" type.
if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv,
        params.type_k_first, params.type_k_last, params.type_v_first, params.type_v_last,
        params.n_k_first, params.n_k_last, params.n_v_first, params.n_v_last)) {
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
llama_free(ctx);
return nullptr;