Fix MLA models with ngl < n_layer (#1870)

* Fix split mode graph with ngl < n_layer (MLA models)

* The underlying issue is actually not related to split mode graph
This commit is contained in:
Kawrakow 2026-05-24 07:29:17 +03:00 committed by GitHub
parent 642c038ccd
commit 809a63bbb7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -857,14 +857,13 @@ static bool llama_kv_cache_init(
}
if (is_mla_attn) {
bool have_wkv_b = true;
int n_have_wkv_b = 0;
for (auto& l : model.layers) {
// Under -sm graph mla>1, wk_b_pp (attn_kv_b) substitutes for wkv_b.
if (!l.wkv_b && !l.wk_b_pp) {
have_wkv_b = false;
break;
if (l.wkv_b || l.wk_b_pp) {
++n_have_wkv_b;
}
}
bool have_wkv_b = n_have_wkv_b > 0;
if (!have_wkv_b) {
if (cparams.mla_attn != 1) {
LLAMA_LOG_WARN("=========================================================\n");
@ -2490,6 +2489,7 @@ static void llm_prepare_mla(llama_model & model, int mla) {
// Second pass: for layers where wk_b came from the GGUF directly, produce
// wk_b_pp here. Only under -sm graph/attn AND mla > 1; mla=1 skips pp_opt.
int n_computed = 0;
if ((model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) && mla > 1) {
int n_pp_to_compute = 0;
for (auto & l : model.layers) {
@ -2537,6 +2537,8 @@ static void llm_prepare_mla(llama_model & model, int mla) {
// transpose in build_deepseek2.cpp; skip here.
if (!l.wo || !l.wo->extra || !l.wk_b->extra) continue;
++n_computed;
// Per-rank wk_b slices: each lives on a single device as a regular CUDA
// tensor (not the split-buffer wrapper which lacks a get_tensor impl for
// split_dim=2). Read each rank's slice independently.
@ -2673,7 +2675,7 @@ static void llm_prepare_mla(llama_model & model, int mla) {
}
}
if (mla == 1 || model.split_mode == LLAMA_SPLIT_MODE_GRAPH) return;
if (mla == 1 || (model.split_mode == LLAMA_SPLIT_MODE_GRAPH && n_computed == n_layer)) return;
n_to_compute = 0;
for (auto& l : model.layers) {
@ -2708,7 +2710,7 @@ static void llm_prepare_mla(llama_model & model, int mla) {
std::vector<char> tmp_buffer;
for (int il = 0; il < n_layer; ++il) {
auto& l = model.layers[il];
if (l.wkv_b || !l.wk_b || !l.wv_b) continue;
if (l.wkv_b || !l.wk_b || !l.wv_b || (l.wo && l.wo->extra)) continue;
auto wk_b = *l.wk_b;
auto wv_b = *l.wv_b;
if (!ggml_backend_buffer_is_host(l.wk_b->buffer)) {
@ -3278,7 +3280,7 @@ static bool llm_load_tensors(
llama_model_loader & ml,
llama_model & model,
int n_gpu_layers,
int mla_attn,
int & mla_attn,
enum llama_split_mode split_mode,
int main_gpu,
int max_gpu,
@ -3421,6 +3423,17 @@ static bool llm_load_tensors(
if (fit && device_count > 1) {
model.main_gpu = device_count - 1;
}
if (model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_GLM_DSA || model.arch == LLM_ARCH_MISTRAL4) {
if (model.n_gpu_layers > 0 && model.n_gpu_layers < model.hparams.n_layer && mla_attn != 3) {
LLAMA_LOG_WARN("=============================================================================\n");
LLAMA_LOG_WARN("MLA models with ngl < n_layer and split mode graph do not work with mla = %d\n", mla_attn);
LLAMA_LOG_WARN(" => changing mla to 3\n");
LLAMA_LOG_WARN("=============================================================================\n");
mla_attn = 3;
}
}
model.default_layer_device = std::vector<int32_t>(hparams.n_layer+1, device_count-1);
int act_gpu_layers = std::min(n_gpu_layers, (int)n_layer + 1);
std::vector<llama_model_tensor_buft_override> overrides;
@ -6746,6 +6759,14 @@ struct llama_context * llama_init_from_model(
if (model->arch != LLM_ARCH_DEEPSEEK2 && model->arch != LLM_ARCH_GLM_DSA && model->arch != LLM_ARCH_MISTRAL4 && cparams.mla_attn != 0) {
cparams.mla_attn = 0;
} else {
if (model->n_gpu_layers > 0 && model->n_gpu_layers < model->hparams.n_layer && cparams.mla_attn != 3) {
LLAMA_LOG_WARN("=============================================================================\n");
LLAMA_LOG_WARN("MLA models with ngl < n_layer and split mode graph do not work with mla = %d\n", cparams.mla_attn);
LLAMA_LOG_WARN(" => changing mla to 3\n");
LLAMA_LOG_WARN("=============================================================================\n");
cparams.mla_attn = 3;
}
}
if (model->arch == LLM_ARCH_OPENAI_MOE && model->split_mode == LLAMA_SPLIT_MODE_GRAPH) {
if (cparams.reduce_type == GGML_TYPE_F16) {