mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Fix cache loading/saving for MLA models and split mode graph (#1884)
This commit is contained in:
parent
4fbd0c441b
commit
d2da6da05c
@ -515,6 +515,10 @@ struct llama_model {
|
||||
return tensor_overrides;
|
||||
}
|
||||
|
||||
bool is_mla_model() const {
|
||||
return arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4;
|
||||
}
|
||||
|
||||
size_t cache_size(int il, ggml_type type_k, ggml_type type_v, uint32_t kv_size, int mla_attn, int n_seq_max, bool flash_attn) const;
|
||||
|
||||
void set_tensor_overrides(const llama_model_params& params);
|
||||
|
||||
@ -8529,16 +8529,27 @@ struct llama_data_read {
|
||||
void read_kv_cache_data_split(llama_context * ctx, ggml_tensor * tensor, const uint8_t * data, size_t head, size_t row_size, int nrows, int il) {
|
||||
GGML_ASSERT(il >= 0 && il < int(ctx->model.layers.size()));
|
||||
GGML_ASSERT(ggml_internal_get_type_traits(tensor->type).row_meta_size == 0);
|
||||
std::vector<uint8_t> aux;
|
||||
auto extra = (ggml_split_tensor_t *)tensor->extra;
|
||||
if (ctx->model.is_mla_model()) {
|
||||
GGML_ASSERT(extra);
|
||||
GGML_ASSERT(ctx->cparams.mla_attn == 3);
|
||||
for (int id = 0; id < extra->n_device; ++id) {
|
||||
auto split = extra->splits[id];
|
||||
if (!split) continue;
|
||||
GGML_ASSERT(split->type == tensor->type);
|
||||
ggml_backend_tensor_set(split, data, head*row_size, nrows*row_size);
|
||||
}
|
||||
return;
|
||||
}
|
||||
bool is_recurrent = ctx->model.hparams.recurrent_layer_arr[il];
|
||||
auto kv = is_recurrent ? nullptr : get_kv_cache_split_tensor(tensor, ctx->model.layers[il]);
|
||||
auto extra = (ggml_split_tensor_t *)tensor->extra;
|
||||
auto kv_extra = kv ? (ggml_split_tensor_t *)kv->extra : nullptr;
|
||||
GGML_ASSERT(extra && (is_recurrent || kv_extra));
|
||||
auto ne = kv ? kv->ne[1] : tensor->ne[0];
|
||||
size_t sum_ne = 0;
|
||||
size_t sum_split_row_size = 0;
|
||||
GGML_ASSERT(row_size == ggml_row_size(tensor->type, ne));
|
||||
std::vector<uint8_t> aux;
|
||||
for (int id = 0; id < extra->n_device; ++id) {
|
||||
auto split = extra->splits[id];
|
||||
auto kv_split = kv_extra ? kv_extra->splits[id] : nullptr;
|
||||
@ -8902,7 +8913,18 @@ struct llama_data_write_buffer : llama_data_write {
|
||||
throw std::runtime_error(std::string{"Split cache for type "} + ggml_type_name(tensor->type) + " is not supported");
|
||||
}
|
||||
GGML_ASSERT(il >= 0 && il < int(model.layers.size()));
|
||||
if (model.hparams.recurrent_layer_arr[il]) {
|
||||
if (model.is_mla_model()) {
|
||||
// For MLA models, the cache is replacated on all GPUs when using split mode graph, so it is
|
||||
// enough to get the data from the 1st device that has a copy
|
||||
auto extra = (const ggml_split_tensor_t *)tensor->extra;
|
||||
for (int id = 0; id < extra->n_device; ++id) {
|
||||
if (extra->splits[id]) {
|
||||
ggml_backend_tensor_get(extra->splits[id], ptr, offset, size);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (model.hparams.recurrent_layer_arr[il]) {
|
||||
get_tensor_data_split(ptr, tensor, aux_buffer, offset, size);
|
||||
} else {
|
||||
auto kv = get_kv_cache_split_tensor(tensor, model.layers[il]);
|
||||
@ -9037,6 +9059,18 @@ struct llama_data_write_file : llama_data_write {
|
||||
|
||||
void get_tensor_data_split(const struct ggml_tensor * tensor, size_t offset, size_t size, int il) {
|
||||
GGML_ASSERT(il >= 0 && il < int(model.layers.size()));
|
||||
if (model.is_mla_model()) {
|
||||
// MLA models have the cache replicated on all devices. Hence, it is enough to get it
|
||||
// from the 1st device that has it.
|
||||
auto extra = (ggml_split_tensor_t *)tensor->extra;
|
||||
for (int id = 0; id < extra->n_device; ++id) {
|
||||
if (extra->splits[id]) {
|
||||
ggml_backend_tensor_get(extra->splits[id], temp_buffer.data(), offset, size);
|
||||
break;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
auto kv = get_kv_cache_split_tensor(tensor, model.layers[il]);
|
||||
temp_buffer.resize(size);
|
||||
llama_data_write_buffer::get_tensor_data_split(temp_buffer.data(), tensor, kv, aux_buffer, offset, size);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user