diff --git a/src/graphs/build_deepseek2.cpp b/src/graphs/build_deepseek2.cpp index 63e6ae21..c6156860 100644 --- a/src/graphs/build_deepseek2.cpp +++ b/src/graphs/build_deepseek2.cpp @@ -13,7 +13,8 @@ ggml_tensor * llm_build_context::build_deepseek2_tp_attention( ggml_tensor * rope_cache, float kq_scale, float attn_factor_scaled, bool use_f32_attn_precision, - bool is_lite) { + bool is_lite, + bool pp_opt) { if (!lctx.cparams.flash_attn || lctx.cparams.mla_attn < 1) { GGML_ABORT("-sm graph for MLA archs (DEEPSEEK2/GLM_DSA/MISTRAL4) requires -fa on and -mla >= 1. " "Got mla_attn=%d, flash_attn=%d.", @@ -140,49 +141,166 @@ ggml_tensor * llm_build_context::build_deepseek2_tp_attention( row_size_cache, 0); cb(kv_cache, "kv_cache", il_id); - // wk_b is split per-head (split_dim=2); each rank's tensor already contains only its n_head_local heads. - auto wk_b_split = (const ggml_split_tensor_t *)model.layers[il].wk_b->extra; - GGML_ASSERT(wk_b_split); - ggml_tensor * wk_b_local = wk_b_split->splits[id]; + // pp_opt (mla > 1, n_tokens >= 128, n_kv >= k_pp_opt_min_kv): materialize + // per-rank K/V from the latent cache and use standard flash_attn instead of + // FlashMLA-3 absorb. 
+ constexpr int k_pp_opt_min_kv = 1024; + const bool tp_pp_opt = pp_opt + && (int)n_kv >= k_pp_opt_min_kv + && model.layers[il].wk_b + && model.layers[il].wv_b + && model.layers[il].wk_b_pp; - ggml_tensor * q_nope_perm = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); - ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b_local, q_nope_perm); + ggml_tensor * kqv_2d; - ggml_tensor * q_combined = ggml_concat(ctx0, - ggml_permute(ctx0, q_rope, 0, 2, 1, 3), q_nope2, 0); - if (cparams.k_cache_hadamard) { - q_combined = ggml_hadamard(ctx0, q_combined, 64); + if (tp_pp_opt) { + // Per-rank wk_b/wv_b slices already exist from distribute_mla_tensors: + // wk_b_local_pp: [n_embd_head_qk_nope, kv_lora_rank, n_head_local] + // wv_b_local_pp: [kv_lora_rank, n_embd_head_v, n_head_local] + auto wk_b_pp_split_raw = (const ggml_split_tensor_t *)model.layers[il].wk_b->extra; + auto wv_b_pp_split_raw = (const ggml_split_tensor_t *)model.layers[il].wv_b->extra; + GGML_ASSERT(wk_b_pp_split_raw && wv_b_pp_split_raw); + ggml_tensor * wk_b_local_pp = wk_b_pp_split_raw->splits[id]; + ggml_tensor * wv_b_local_pp = wv_b_pp_split_raw->splits[id]; + + ggml_tensor * kv_cache_nope = ggml_view_2d(ctx0, cache_local, + kv_lora_rank, n_kv, + row_size_cache, + ggml_row_size(cache_local->type, n_embd_head_qk_rope)); + cb(kv_cache_nope, "kv_cache_nope_pp", il_id); + + ggml_tensor * kv_cache_rope_view = ggml_view_3d(ctx0, cache_local, + n_embd_head_qk_rope, n_kv, 1, + row_size_cache, cache_local->nb[2], 0); + cb(kv_cache_rope_view, "kv_cache_rope_pp", il_id); + + // Hadamard cache was applied per 64-block during write; un-Hadamard the + // read views so the materialize mul_mats see the original latents. Hadamard + // requires F32 input, so dequantize the cache views first when the cache is + // quantized. Hadamard is its own inverse (the impl handles the scale). + if (cparams.k_cache_hadamard) { + ggml_tensor * kn_f32 = kv_cache_nope->type == GGML_TYPE_F32 + ? 
kv_cache_nope + : ggml_cast(ctx0, kv_cache_nope, GGML_TYPE_F32); + ggml_tensor * kr_f32 = kv_cache_rope_view->type == GGML_TYPE_F32 + ? kv_cache_rope_view + : ggml_cast(ctx0, kv_cache_rope_view, GGML_TYPE_F32); + kv_cache_nope = ggml_hadamard(ctx0, kn_f32, 64); + kv_cache_rope_view = ggml_hadamard(ctx0, kr_f32, 64); + } + + // CUDA quantized-cache + REPEAT/CONCAT/CPY has known issues, so force F16 here. + const auto kv_type = GGML_TYPE_F16; + + ggml_tensor repeater; + repeater.ne[0] = n_embd_head_qk_rope; + repeater.ne[1] = n_kv; + repeater.ne[2] = n_head_local; + repeater.ne[3] = 1; + ggml_tensor * k_rope_rep; + if (kv_cache_rope_view->type == kv_type) { + k_rope_rep = ggml_repeat(ctx0, kv_cache_rope_view, &repeater); + } else { + auto kv_rope_f16 = ggml_cast(ctx0, kv_cache_rope_view, kv_type); + k_rope_rep = ggml_repeat(ctx0, kv_rope_f16, &repeater); + } + cb(k_rope_rep, "k_rope_rep_pp", il_id); + + // V: wv_b_local viewed as 2D [kv_lora_rank, n_head_local * n_embd_head_v]. + // Per-rank, no cross-device transfer per call. + auto wv_b_2d = ggml_reshape_2d(ctx0, wv_b_local_pp, + kv_lora_rank, n_head_local * n_embd_head_v); + ggml_tensor * v_2d = ggml_mul_mat(ctx0, wv_b_2d, kv_cache_nope); + cb(v_2d, "v_2d_pp", il_id); + ggml_tensor * v_f32 = ggml_view_3d(ctx0, v_2d, + n_embd_head_v, n_kv, n_head_local, + v_2d->nb[1], + n_embd_head_v * v_2d->nb[0], + 0); + + // wk_b_pp is transpose(wk_b) pre-materialized in llm_prepare_mla. + // Shape: [kv_lora_rank, n_embd_head_qk_nope, n_head_local]. 
+ auto wk_b_pp_split = (const ggml_split_tensor_t *)model.layers[il].wk_b_pp->extra; + GGML_ASSERT(wk_b_pp_split); + ggml_tensor * wk_b_pp_local = wk_b_pp_split->splits[id]; + GGML_ASSERT(wk_b_pp_local); + ggml_tensor * wk_b_T_2d = ggml_reshape_2d(ctx0, wk_b_pp_local, + kv_lora_rank, n_head_local * n_embd_head_qk_nope); + ggml_tensor * k_nope_2d = ggml_mul_mat(ctx0, wk_b_T_2d, kv_cache_nope); + cb(k_nope_2d, "k_nope_2d_pp", il_id); + ggml_tensor * k_nope_f32 = ggml_view_3d(ctx0, k_nope_2d, + n_embd_head_qk_nope, n_kv, n_head_local, + k_nope_2d->nb[1], + n_embd_head_qk_nope * k_nope_2d->nb[0], + 0); + + ggml_tensor * v = ggml_cast(ctx0, v_f32, kv_type); + ggml_tensor * k_nope = ggml_cast(ctx0, k_nope_f32, kv_type); + ggml_build_forward_expand(gf, v); + ggml_build_forward_expand(gf, k_nope); + + ggml_tensor * k = ggml_concat(ctx0, k_rope_rep, k_nope, 0); + ggml_build_forward_expand(gf, k); + cb(k, "k_full_pp", il_id); + + ggml_tensor * q = ggml_concat(ctx0, q_rope, q_nope, 0); + q = ggml_permute(ctx0, q, 0, 2, 1, 3); + ggml_build_forward_expand(gf, q); + cb(q, "q_concat_pp", il_id); + + ggml_tensor * kqv = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, + kq_scale, hparams.f_max_alibi_bias, 0.f); + if (use_f32_attn_precision || q->ne[1] <= 8) { + ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32); + } + cb(kqv, "kqv_pp", il_id); + + kqv_2d = ggml_reshape_2d(ctx0, kqv, n_embd_head_v * n_head_local, n_tokens); + } else { + // Absorb path: FlashMLA-3 with the compressed latent cache, then project via wv_b. 
+ auto wk_b_split = (const ggml_split_tensor_t *)model.layers[il].wk_b->extra; + GGML_ASSERT(wk_b_split); + ggml_tensor * wk_b_local = wk_b_split->splits[id]; + + ggml_tensor * q_nope_perm = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); + ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b_local, q_nope_perm); + + ggml_tensor * q_combined = ggml_concat(ctx0, + ggml_permute(ctx0, q_rope, 0, 2, 1, 3), q_nope2, 0); + if (cparams.k_cache_hadamard) { + q_combined = ggml_hadamard(ctx0, q_combined, 64); + } + + // FlashMLA-3 path: K = kv_cache (full latent + rope), V = kv_cache_lora (latent only) + ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, cache_local, + kv_lora_rank, n_kv, + row_size_cache, + ggml_row_size(cache_local->type, n_embd_head_qk_rope)); + cb(kv_cache_lora, "kv_cache_lora", il_id); + + ggml_tensor * kqv_compressed = ggml_flash_attn_ext(ctx0, + q_combined, kv_cache, kv_cache_lora, KQ_mask, + kq_scale, hparams.f_max_alibi_bias, 0.f); + cb(kqv_compressed, "kqv_compressed", il_id); + if (use_f32_attn_precision) { + ggml_flash_attn_ext_set_prec(kqv_compressed, GGML_PREC_F32); + } + if (cparams.k_cache_hadamard) { + kqv_compressed = ggml_hadamard(ctx0, kqv_compressed, 64); + } + kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); + + auto wv_b_split = (const ggml_split_tensor_t *)model.layers[il].wv_b->extra; + GGML_ASSERT(wv_b_split); + ggml_tensor * wv_b_local = wv_b_split->splits[id]; + + ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b_local, kqv_compressed); + if (n_tokens > 1) { + kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3)); + } + kqv_2d = ggml_reshape_2d(ctx0, kqv, n_embd_head_v * n_head_local, n_tokens); } - // FlashMLA-3 path: K = kv_cache (full latent + rope), V = kv_cache_lora (latent only) - ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, cache_local, - kv_lora_rank, n_kv, - row_size_cache, - ggml_row_size(cache_local->type, n_embd_head_qk_rope)); - cb(kv_cache_lora, "kv_cache_lora", il_id); - - ggml_tensor * kqv_compressed = 
ggml_flash_attn_ext(ctx0, - q_combined, kv_cache, kv_cache_lora, KQ_mask, - kq_scale, hparams.f_max_alibi_bias, 0.f); - cb(kqv_compressed, "kqv_compressed", il_id); - if (use_f32_attn_precision) { - ggml_flash_attn_ext_set_prec(kqv_compressed, GGML_PREC_F32); - } - if (cparams.k_cache_hadamard) { - kqv_compressed = ggml_hadamard(ctx0, kqv_compressed, 64); - } - kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); - - auto wv_b_split = (const ggml_split_tensor_t *)model.layers[il].wv_b->extra; - GGML_ASSERT(wv_b_split); - ggml_tensor * wv_b_local = wv_b_split->splits[id]; - - ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b_local, kqv_compressed); - if (n_tokens > 1) { - kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3)); - } - ggml_tensor * kqv_2d = ggml_reshape_2d(ctx0, kqv, n_embd_head_v * n_head_local, n_tokens); - ggml_tensor * partial = llm_build_lora_mm(lctx, ctx0, wo_split->splits[id], kqv_2d); // Fold residual into the first non-skipped rank so the reduce result includes it. @@ -677,7 +795,7 @@ ggml_cgraph * llm_build_context::build_deepseek2() { // whether to use n_tokens as the matrix dimension during multiplication or n_head // n_tokens is higher during prompt processing, this allows to optimize for this case - bool pp_opt = n_tokens >= 128; // Is it a fixed constant or is it somehow relared to n_head? original: n_tokens > n_head; + bool pp_opt = n_tokens >= 128 && lctx.cparams.mla_attn > 1; auto rope_cache = cparams.rope_cache && (rope_type == LLAMA_ROPE_TYPE_NEOX || rope_type == LLAMA_ROPE_TYPE_NORM) ? 
ggml_rope_cache(ctx0, inp_pos, nullptr, n_rot, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -692,7 +810,7 @@ ggml_cgraph * llm_build_context::build_deepseek2() { if (is_tp_layer) { cur = build_deepseek2_tp_attention(gf, il, inpL, KQ_mask, inp_pos, rope_cache, kq_scale, attn_factor_scaled, - use_f32_attn_precision, is_lite); + use_f32_attn_precision, is_lite, pp_opt); } else { cur = build_deepseek2_layer_attention(gf, il, inpL, KQ_mask, inp_pos, rope_cache, kq_scale, attn_factor_scaled, diff --git a/src/llama-build-context.h b/src/llama-build-context.h index 55b29206..ab6129ca 100644 --- a/src/llama-build-context.h +++ b/src/llama-build-context.h @@ -265,7 +265,8 @@ struct llm_build_context { ggml_tensor * rope_cache, float kq_scale, float attn_factor_scaled, bool use_f32_attn_precision, - bool is_lite); + bool is_lite, + bool pp_opt); ggml_tensor * build_deepseek2_layer_attention( ggml_cgraph * gf, int il, diff --git a/src/llama-model.h b/src/llama-model.h index c46c2378..14e2d44a 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -177,6 +177,9 @@ struct llama_layer { struct ggml_tensor * wkq_a_mqa = nullptr; struct ggml_tensor * wkv_b = nullptr; struct ggml_tensor * wk_b = nullptr; + // wk_b in pp_opt-favoring layout [kv_lora_rank, qk_nope, n_head], named "blk.<il>.attn_kv_b.weight". + // Built by llm_prepare_mla under -sm graph/attn only; additionally requires mla>1 when wk_b was already present rather than computed there. 
+ struct ggml_tensor * wk_b_pp = nullptr; struct ggml_tensor * wv_b = nullptr; struct ggml_tensor * wq_cross = nullptr; struct ggml_tensor * wk_cross = nullptr; @@ -224,6 +227,7 @@ struct llama_layer { llama_split_tensor split_wq_b; llama_split_tensor split_wkv_a_mqa; llama_split_tensor split_wk_b; + llama_split_tensor split_wk_b_pp; llama_split_tensor split_wv_b; llama_split_tensor split_attn_q_a_norm; llama_split_tensor split_attn_kv_a_norm; @@ -381,11 +385,13 @@ struct llama_layer { struct llama_layer_nextn nextn; std::unique_ptr computed_wk_b; + std::unique_ptr computed_wk_b_pp; std::unique_ptr computed_wv_b; std::unique_ptr computed_wkv_b; // Per-device replicas of computed wk_b/wv_b (-sm graph). Buffers owned via model.bufs. std::vector> computed_wk_b_replicas; + std::vector> computed_wk_b_pp_replicas; std::vector> computed_wv_b_replicas; }; diff --git a/src/llama.cpp b/src/llama.cpp index 69ee1066..bea371ee 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -859,7 +859,8 @@ static bool llama_kv_cache_init( if (is_mla_attn) { bool have_wkv_b = true; for (auto& l : model.layers) { - if (!l.wkv_b) { + // Under -sm graph mla>1, wk_b_pp (attn_kv_b) substitutes for wkv_b. + if (!l.wkv_b && !l.wk_b_pp) { have_wkv_b = false; break; } @@ -2285,7 +2286,10 @@ static void llm_prepare_mla(llama_model & model, int mla) { } } } - auto context_size = max_wk_size + 2*n_embd_head_qk_nope*kv_lora_rank*n_head*sizeof(float); + // tensor_data layout: [wk_b_f32 | wk_b_f32_t | wk_b | wk_b_pp]. wk_b_pp slot is only + // populated under -sm graph/attn (pp_opt-favoring layout); allocated unconditionally + // for simplicity (one max_wk_size buffer). 
+ auto context_size = 2*max_wk_size + 2*n_embd_head_qk_nope*kv_lora_rank*n_head*sizeof(float); context_size *= 2; // just in case; std::vector wkv_buffer; if (max_wkv_size > 0) wkv_buffer.resize(max_wkv_size); @@ -2296,7 +2300,7 @@ static void llm_prepare_mla(llama_model & model, int mla) { ggml_init_params params{context_size, nullptr, true}; auto ctx = ggml_init(params); auto graph = ggml_new_graph_custom(ctx, 8, false); - std::vector tensor_data(2*n_embd_head_qk_nope*kv_lora_rank*n_head*sizeof(float) + max_wk_size); + std::vector tensor_data(2*n_embd_head_qk_nope*kv_lora_rank*n_head*sizeof(float) + 2*max_wk_size); for (int il = 0; il < n_layer; ++il) { auto& l = model.layers[il]; if (l.wk_b) continue; @@ -2375,8 +2379,15 @@ static void llm_prepare_mla(llama_model & model, int mla) { const size_t slice_bytes = (size_t)n_head_local * head_block_bytes; auto dev_buft = ggml_backend_buffer_get_type(wo_split->splits[id]->buffer); auto dev_buf = ggml_backend_buft_alloc_buffer(dev_buft, slice_bytes); + if (!dev_buf) { + throw std::runtime_error("Failed to allocate per-rank buffer for " + tname); + } ggml_backend_buffer_set_usage(dev_buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); model.bufs.push_back(dev_buf); + // Intentionally not updating mem_used[id] here: llm_prepare_mla + // runs post-load, after distribute_mla_tensors has completed its + // allocation rounds, so no downstream allocator consults mem_used + // anymore. replicas[id] = std::make_unique(*source); auto rep = replicas[id].get(); @@ -2429,6 +2440,32 @@ static void llm_prepare_mla(llama_model & model, int mla) { return computed.get(); }; + // pp_opt-favoring wk_b_pp = quantize(wk_b_f32) directly (absorb-favoring + // wk_b above is its transpose). Not gated on mla — wk_b_pp shares wk_b_f32 + // with the wk_b synthesis above and skipping it breaks absorb on some + // quant combinations. Second pass below gates the mla=1 saving for + // GGUFs that ship wk_b directly. 
+ if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) { + ggml_graph_clear(graph); + auto wk_b_pp = ggml_cast(ctx, wk_b_f32, new_type); + wk_b_pp->data = (char *)wk_b->data + ggml_nbytes(wk_b); + // tensor_data layout is [wk_b_f32 | wk_b_f32_t | wk_b | wk_b_pp]. wk_b and + // wk_b_pp are the same quant type and same shape, so wk_b_pp's slot is + // exactly one wk_b-sized block past wk_b. The tensor_data vector was sized + // 2*F32 + 2*max_wk_size, which covers both wk_b and wk_b_pp. + GGML_ASSERT((char *)wk_b_pp->data + ggml_nbytes(wk_b_pp) <= + (char *)tensor_data.data() + tensor_data.size()); + ggml_build_forward_expand(graph, wk_b_pp); + auto plan_pp = ggml_graph_plan(graph, std::thread::hardware_concurrency()/2); + if (plan_pp.work_size > work_data.size()) work_data.resize(plan_pp.work_size); + plan_pp.work_data = work_data.data(); + auto status_pp = ggml_graph_compute(graph, &plan_pp); + if (status_pp != GGML_STATUS_SUCCESS) throw std::runtime_error("Failed to compute attn_kv_b"); + auto name_pp = std::string{"blk."} + std::to_string(il) + ".attn_kv_b.weight"; + l.wk_b_pp = materialize(wk_b_pp, l.computed_wk_b_pp, l.computed_wk_b_pp_replicas, l.split_wk_b_pp, name_pp); + ggml_graph_clear(graph); + } + l.wk_b = materialize(wk_b, l.computed_wk_b, l.computed_wk_b_replicas, l.split_wk_b, name); ggml_graph_clear(graph); @@ -2450,6 +2487,192 @@ static void llm_prepare_mla(llama_model & model, int mla) { } ggml_free(ctx); } + + // Second pass: for layers where wk_b came from the GGUF directly, produce + // wk_b_pp here. Only under -sm graph/attn AND mla > 1; mla=1 skips pp_opt. 
+    if ((model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) && mla > 1) {
+        int n_pp_to_compute = 0;
+        for (auto & l : model.layers) {
+            if (l.wk_b && !l.wk_b_pp) ++n_pp_to_compute;
+        }
+        if (n_pp_to_compute > 0) {
+            const uint32_t n_embd_head_qk_nope_pp = hparams.n_embd_head_k(0) - hparams.n_rot;
+            const uint32_t kv_lora_rank_pp        = hparams.n_lora_kv;
+            const int32_t  n_head_pp              = hparams.n_head(0);
+
+            size_t max_wk_size_pp = 0;
+            for (auto & l : model.layers) {
+                if (l.wk_b && !l.wk_b_pp) {
+                    max_wk_size_pp = std::max(max_wk_size_pp, ggml_nbytes(l.wk_b));
+                }
+            }
+            auto context_size_pp = 4 * max_wk_size_pp
+                + 4 * (size_t)n_embd_head_qk_nope_pp * kv_lora_rank_pp * n_head_pp * sizeof(float);
+            context_size_pp *= 2;
+
+            std::vector<uint8_t> work_data_pp; // scratch for ggml_cplan::work_data, grown on demand
+            // Hoist the full-wk_b assembly buffer outside the per-layer loop so we
+            // don't re-allocate ~max_wk_size_pp bytes per layer.
+            std::vector<char> full_wk_b_host_buf(max_wk_size_pp);
+            // tensor_data_pp holds, in order: [wk_b_f32 | wk_b_pp_f32 | wk_b_pp_q].
+            // F32 slots size by n_embd_head_qk_nope * kv_lora_rank * n_head * 4; the
+            // requantized slot fits in max_wk_size_pp. After the N->1 graph compute
+            // consolidation we no longer need a second wk_b-sized slot.
+            std::vector<char> tensor_data_pp(
+                    2 * (size_t)n_embd_head_qk_nope_pp * kv_lora_rank_pp * n_head_pp * sizeof(float)
+                    + max_wk_size_pp);
+
+            ggml_init_params params_pp{context_size_pp, nullptr, true};
+            auto ctx_pp = ggml_init(params_pp);
+            std::unique_ptr<ggml_context, decltype(&ggml_free)> ctx_pp_guard(ctx_pp, &ggml_free); // frees ctx_pp on every exit path, including the throws below
+            auto graph_pp = ggml_new_graph_custom(ctx_pp, 8, false);
+            LLAMA_LOG_INFO("============ %s: need to compute %d attn_kv_b tensors\n", __func__, n_pp_to_compute);
+
+            for (int il = 0; il < n_layer; ++il) {
+                auto & l = model.layers[il];
+                if (!l.wk_b || l.wk_b_pp) continue;
+
+                // Under -sm graph/attn (the outer block's gate) both l.wk_b->extra and
+                // l.wo->extra are expected to be populated. If either is missing, skip:
+                // l.wk_b_pp stays nullptr, so build_deepseek2_tp_attention keeps this
+                // layer on the absorb path (its pp_opt gate requires wk_b_pp).
+                if (!l.wo || !l.wo->extra || !l.wk_b->extra) continue;
+
+                // Per-rank wk_b slices: each lives on a single device as a regular CUDA
+                // tensor (not the split-buffer wrapper which lacks a get_tensor impl for
+                // split_dim=2). Read each rank's slice independently.
+                auto wk_b_split = (const ggml_split_tensor_t *)l.wk_b->extra;
+                auto wo_split   = (const ggml_split_tensor_t *)l.wo->extra;
+                const int n_device = wo_split->n_device;
+
+                auto name = std::string{"blk."} + std::to_string(il) + ".attn_kv_b.weight";
+
+                // Build a placeholder wk_b_pp_q on host with the full [kv_lora_rank, qk_nope, n_head]
+                // shape (only used as a template for cloning per-rank metadata; no data filled).
+                ggml_tensor wk_b_pp_template = *l.wk_b;
+                wk_b_pp_template.ne[0] = (int64_t)kv_lora_rank_pp;
+                wk_b_pp_template.ne[1] = (int64_t)n_embd_head_qk_nope_pp;
+                wk_b_pp_template.ne[2] = (int64_t)n_head_pp;
+                wk_b_pp_template.nb[0] = ggml_type_size(l.wk_b->type);
+                wk_b_pp_template.nb[1] = wk_b_pp_template.nb[0] * (wk_b_pp_template.ne[0] / ggml_blck_size(l.wk_b->type));
+                wk_b_pp_template.nb[2] = wk_b_pp_template.nb[1] * wk_b_pp_template.ne[1];
+                wk_b_pp_template.nb[3] = wk_b_pp_template.nb[2] * wk_b_pp_template.ne[2];
+
+                l.computed_wk_b_pp = std::make_unique<ggml_tensor>(wk_b_pp_template);
+                l.computed_wk_b_pp->buffer = nullptr;
+                l.computed_wk_b_pp->data   = nullptr;
+                l.computed_wk_b_pp->op     = GGML_OP_NONE;
+                for (int j = 0; j < GGML_MAX_SRC; ++j) l.computed_wk_b_pp->src[j] = nullptr;
+                ggml_set_name(l.computed_wk_b_pp.get(), name.c_str());
+
+                l.computed_wk_b_pp_replicas.resize(n_device);
+                l.split_wk_b_pp.tensor_splits.assign(n_device, nullptr);
+                const size_t per_head_pp_bytes = wk_b_pp_template.nb[2];
+
+                // Build head_offsets[] from per-rank ne[2]; matches the layout used by
+                // prepare_split_tensors(split_dim=2) on wk_b so that head_offsets[id]
+                // points to the start of rank id's head range in the full wk_b layout.
+                std::vector<int> head_offsets(n_device + 1, 0);
+                for (int id = 0; id < n_device; ++id) {
+                    int n_h_id = 0;
+                    if (wk_b_split->splits[id]) {
+                        n_h_id = (int)wk_b_split->splits[id]->ne[2];
+                    }
+                    head_offsets[id + 1] = head_offsets[id] + n_h_id;
+                }
+
+                // Read all per-rank wk_b slices into the hoisted host buffer ordered
+                // by head_offset. The data layout on disk is per-head contiguous, so
+                // sequential rank reads at byte offset head_offsets[id] * per_head_in_bytes
+                // reconstitute the original full wk_b on host.
+                const size_t per_head_in_bytes = l.wk_b->nb[2];
+                const size_t full_wk_b_nbytes  = (size_t)head_offsets[n_device] * per_head_in_bytes;
+                GGML_ASSERT(full_wk_b_nbytes <= full_wk_b_host_buf.size());
+                for (int id = 0; id < n_device; ++id) {
+                    if (!wk_b_split->splits[id]) continue;
+                    auto wk_b_rank = wk_b_split->splits[id];
+                    const size_t rank_nbytes = ggml_nbytes(wk_b_rank);
+                    const size_t byte_offset = (size_t)head_offsets[id] * per_head_in_bytes;
+                    GGML_ASSERT(byte_offset + rank_nbytes <= full_wk_b_nbytes);
+                    ggml_backend_tensor_get(wk_b_rank, full_wk_b_host_buf.data() + byte_offset, 0, rank_nbytes);
+                }
+
+                // ONE graph compute on the full assembled wk_b: dequant -> transpose ->
+                // requant. Per-rank slicing happens on the resulting host buffer.
+                ggml_tensor full_host = *l.wk_b;
+                full_host.data = full_wk_b_host_buf.data();
+                auto f_f32 = ggml_cast(ctx_pp, &full_host, GGML_TYPE_F32);
+                f_f32->data = tensor_data_pp.data();
+                auto f_view = ggml_transpose(ctx_pp, f_f32);
+                auto f_cont = ggml_cont(ctx_pp, f_view);
+                f_cont->data = (char *)f_f32->data + ggml_nbytes(f_f32);
+                auto f_q = ggml_cast(ctx_pp, f_cont, l.wk_b->type);
+                f_q->data = (char *)f_cont->data + ggml_nbytes(f_cont);
+                ggml_build_forward_expand(graph_pp, f_q);
+                auto plan = ggml_graph_plan(graph_pp, std::thread::hardware_concurrency()/2);
+                if (plan.work_size > work_data_pp.size()) work_data_pp.resize(plan.work_size);
+                plan.work_data = work_data_pp.data();
+                auto status = ggml_graph_compute(graph_pp, &plan);
+                if (status != GGML_STATUS_SUCCESS) throw std::runtime_error("Failed to compute attn_kv_b");
+                ggml_graph_clear(graph_pp);
+
+                // Per-rank upload: f_q is the full wk_b_pp in [kv_lora_rank, qk_nope,
+                // n_head] order. Per-head stride matches per_head_pp_bytes by construction.
+                for (int id = 0; id < n_device; ++id) {
+                    if (!wo_split->splits[id] || !wo_split->splits[id]->buffer) continue;
+                    if (!wk_b_split->splits[id]) continue;
+                    const int n_head_local = (int)wk_b_split->splits[id]->ne[2];
+                    if (n_head_local <= 0) continue;
+
+                    const size_t slice_bytes = (size_t)n_head_local * per_head_pp_bytes;
+                    auto dev_buft = ggml_backend_buffer_get_type(wo_split->splits[id]->buffer);
+                    auto dev_buf  = ggml_backend_buft_alloc_buffer(dev_buft, slice_bytes);
+                    if (!dev_buf) {
+                        throw std::runtime_error("Failed to allocate per-rank buffer for " + name);
+                    }
+                    ggml_backend_buffer_set_usage(dev_buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+                    model.bufs.push_back(dev_buf);
+                    // Same mem_used note as Path A: distribute_mla_tensors has already
+                    // closed its books before llm_prepare_mla runs.
+
+                    l.computed_wk_b_pp_replicas[id] = std::make_unique<ggml_tensor>(wk_b_pp_template);
+                    auto rep = l.computed_wk_b_pp_replicas[id].get();
+                    rep->ne[2]  = n_head_local;
+                    rep->nb[3]  = rep->nb[2] * (size_t)rep->ne[2];
+                    rep->buffer = dev_buf;
+                    rep->data   = ggml_backend_buffer_get_base(dev_buf);
+                    rep->op     = GGML_OP_NONE;
+                    for (int j = 0; j < GGML_MAX_SRC; ++j) rep->src[j] = nullptr;
+                    rep->view_src  = nullptr;
+                    rep->view_offs = 0;
+                    rep->extra     = nullptr;
+                    ggml_set_name(rep, (name + "." + std::to_string(id)).c_str());
+
+                    const size_t byte_offset = (size_t)head_offsets[id] * per_head_pp_bytes;
+                    ggml_backend_tensor_set(rep, (char *)f_q->data + byte_offset, 0, slice_bytes);
+                    if (ggml_backend_buffer_is_host(rep->buffer)) {
+                        iqk_modify_tensor(rep);
+                    }
+                    l.split_wk_b_pp.tensor_splits[id] = rep;
+                }
+
+                l.split_wk_b_pp.ggml.n_device  = n_device;
+                l.split_wk_b_pp.ggml.split_dim = 2;
+                l.split_wk_b_pp.ggml.splits    = l.split_wk_b_pp.tensor_splits.data();
+                l.computed_wk_b_pp->extra = (void *)&l.split_wk_b_pp.ggml;
+                model.tensors_by_name.push_back(std::make_pair(name, l.computed_wk_b_pp.get()));
+                l.wk_b_pp = l.computed_wk_b_pp.get();
+
+                printf("Computed %s as %d x %d x %d of type %s, split across %d devices on dim=2\n",
+                        name.c_str(),
+                        (int)l.computed_wk_b_pp->ne[0],
+                        (int)l.computed_wk_b_pp->ne[1],
+                        (int)l.computed_wk_b_pp->ne[2],
+                        ggml_type_name(l.computed_wk_b_pp->type), n_device);
+            }
+        }
+    }
+
     if (mla == 1 || model.split_mode == LLAMA_SPLIT_MODE_GRAPH) return;
 
     n_to_compute = 0;