mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
MLA TP prompt processing optimisation (#1841)
* MLA TP prompt processing optimisation

  Adds a per-rank prompt-processing path to build_deepseek2_tp_attention that materialises K/V from the compressed latent cache and runs a standard flash_attn instead of the FlashMLA-3 absorb kernel the TP attention currently uses for all batch sizes. Affects MLA archs under -sm graph (DEEPSEEK2, GLM_DSA, MISTRAL4).

  Gated on n_tokens >= 128 (set by caller) AND n_kv >= 1024. Below either threshold the absorb path runs unchanged. Token generation takes the absorb path; only prompt processing at non-trivial context materialises.

  A second piece pre-computes wk_b in a pp_opt-favouring orientation (wk_b_pp: [kv_lora_rank, qk_nope, n_head]) at llm_prepare_mla time, so the per-PP-call materialise can mul_mat against the latent cache directly without an F16 cast + permute + ggml_cont on wk_b each call. Path A (wkv_b in GGUF) and Path B (only wk_b/wv_b in GGUF) both populate wk_b_pp through the standard per-rank replica setup.

  Measured on 8x RTX 3090, -sm graph -mla 2 -fa on:
    DSV2.5 IQ2_XS c=8k ub=2048 PP +51% to +60%
    GLM-4.7-Flash IQ4_XS c=32k ub=2048 PP -6% (PP@0) to +77% (PP@30720)
    GLM-5.1 IQ1_S q4_0 c=16k ub=2048 PP +5% to +9%

  PPL parity within +/-0.2 noise (DSV2.5 bit-identical 5.3917, GLM-4.7 8.83 vs 8.96, GLM-5.1 6.96 vs 7.00). Token-generation throughput unchanged within noise.

  Compute buffer at init:
    DSV2.5 -54 MiB total (allocator noise)
    GLM-4.7-Flash +1042 MiB total (~+173 MiB per non-output device)
    GLM-5.1 0 (MoE intermediates dominate)

* MLA TP: respect mla=1 vs mla=3 distinction, rename attn_k_b_pp -> attn_kv_b

  ikawrakow/ik_llama.cpp#1841 review feedback: the pp_opt path lost the intended trade-off where mla=1 forgoes pp_opt to save VRAM and mla=3 pays the wk_b_pp tensor cost for faster long-context PP.

  - llm_prepare_mla second pass: gate wk_b_pp synthesis on mla > 1. Models that ship wk_b in their GGUF (mainline format) no longer allocate the pp_opt-favoring K weight under mla=1.
  - llm_prepare_mla first pass (wk_b synthesis from wkv_b): keep unconditional under -sm graph. The wk_b_pp materialization here shares the wk_b_f32 intermediate with the wk_b synthesis above, and isolating just the wk_b_pp branch leaves the synthesized wk_b in a state that makes the absorb path produce inf on some quant combos (DSV2.5 IQ2_XS). Trade: the synthesized-wkv_b path still pays the wk_b_pp allocation under mla=1, but the bigger compute-buffer saving (no pp_opt branch at runtime) still applies.
  - build_deepseek2 outer pp_opt: include cparams.mla_attn > 1 in the pp_opt definition itself, so mla=1 is bypassed throughout (TP and non-TP attention paths).
  - build_deepseek2 tp pp_opt: require wk_b_pp present. Drop the dead runtime wk_b transpose fallback (unreachable now that wk_b_pp is guaranteed when tp_pp_opt fires).
  - llama_kv_cache_init: have_wkv_b probe now treats wk_b_pp (attn_kv_b) as equivalent to wkv_b for the purposes of allowing mla>1 to stay put. Without this, -sm graph models that have wk_b/wv_b separately in the GGUF (no combined wkv_b) would silently downgrade to mla=1.
  - Rename the synthesized tensor "attn_k_b_pp.weight" -> "attn_kv_b.weight" to match the mainline naming ik uses.

  GLM-5.1 in particular benefits: its mla=3 PP improvement over mla=1 is negligible on this arch (~0.4% in our sweeps), so users save the runtime cost by sticking to mla=1.
This commit is contained in:
parent
40254a51da
commit
dd67a9fb24
@ -13,7 +13,8 @@ ggml_tensor * llm_build_context::build_deepseek2_tp_attention(
|
||||
ggml_tensor * rope_cache,
|
||||
float kq_scale, float attn_factor_scaled,
|
||||
bool use_f32_attn_precision,
|
||||
bool is_lite) {
|
||||
bool is_lite,
|
||||
bool pp_opt) {
|
||||
if (!lctx.cparams.flash_attn || lctx.cparams.mla_attn < 1) {
|
||||
GGML_ABORT("-sm graph for MLA archs (DEEPSEEK2/GLM_DSA/MISTRAL4) requires -fa on and -mla >= 1. "
|
||||
"Got mla_attn=%d, flash_attn=%d.",
|
||||
@ -140,49 +141,166 @@ ggml_tensor * llm_build_context::build_deepseek2_tp_attention(
|
||||
row_size_cache, 0);
|
||||
cb(kv_cache, "kv_cache", il_id);
|
||||
|
||||
// wk_b is split per-head (split_dim=2); each rank's tensor already contains only its n_head_local heads.
|
||||
auto wk_b_split = (const ggml_split_tensor_t *)model.layers[il].wk_b->extra;
|
||||
GGML_ASSERT(wk_b_split);
|
||||
ggml_tensor * wk_b_local = wk_b_split->splits[id];
|
||||
// pp_opt (mla > 1, n_tokens >= 128, n_kv >= k_pp_opt_min_kv): materialize
|
||||
// per-rank K/V from the latent cache and use standard flash_attn instead of
|
||||
// FlashMLA-3 absorb.
|
||||
constexpr int k_pp_opt_min_kv = 1024;
|
||||
const bool tp_pp_opt = pp_opt
|
||||
&& (int)n_kv >= k_pp_opt_min_kv
|
||||
&& model.layers[il].wk_b
|
||||
&& model.layers[il].wv_b
|
||||
&& model.layers[il].wk_b_pp;
|
||||
|
||||
ggml_tensor * q_nope_perm = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
|
||||
ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b_local, q_nope_perm);
|
||||
ggml_tensor * kqv_2d;
|
||||
|
||||
ggml_tensor * q_combined = ggml_concat(ctx0,
|
||||
ggml_permute(ctx0, q_rope, 0, 2, 1, 3), q_nope2, 0);
|
||||
if (cparams.k_cache_hadamard) {
|
||||
q_combined = ggml_hadamard(ctx0, q_combined, 64);
|
||||
if (tp_pp_opt) {
|
||||
// Per-rank wk_b/wv_b slices already exist from distribute_mla_tensors:
|
||||
// wk_b_local_pp: [n_embd_head_qk_nope, kv_lora_rank, n_head_local]
|
||||
// wv_b_local_pp: [kv_lora_rank, n_embd_head_v, n_head_local]
|
||||
auto wk_b_pp_split_raw = (const ggml_split_tensor_t *)model.layers[il].wk_b->extra;
|
||||
auto wv_b_pp_split_raw = (const ggml_split_tensor_t *)model.layers[il].wv_b->extra;
|
||||
GGML_ASSERT(wk_b_pp_split_raw && wv_b_pp_split_raw);
|
||||
ggml_tensor * wk_b_local_pp = wk_b_pp_split_raw->splits[id];
|
||||
ggml_tensor * wv_b_local_pp = wv_b_pp_split_raw->splits[id];
|
||||
|
||||
ggml_tensor * kv_cache_nope = ggml_view_2d(ctx0, cache_local,
|
||||
kv_lora_rank, n_kv,
|
||||
row_size_cache,
|
||||
ggml_row_size(cache_local->type, n_embd_head_qk_rope));
|
||||
cb(kv_cache_nope, "kv_cache_nope_pp", il_id);
|
||||
|
||||
ggml_tensor * kv_cache_rope_view = ggml_view_3d(ctx0, cache_local,
|
||||
n_embd_head_qk_rope, n_kv, 1,
|
||||
row_size_cache, cache_local->nb[2], 0);
|
||||
cb(kv_cache_rope_view, "kv_cache_rope_pp", il_id);
|
||||
|
||||
// Hadamard cache was applied per 64-block during write; un-Hadamard the
|
||||
// read views so the materialize mul_mats see the original latents. Hadamard
|
||||
// requires F32 input, so dequantize the cache views first when the cache is
|
||||
// quantized. Hadamard is its own inverse (the impl handles the scale).
|
||||
if (cparams.k_cache_hadamard) {
|
||||
ggml_tensor * kn_f32 = kv_cache_nope->type == GGML_TYPE_F32
|
||||
? kv_cache_nope
|
||||
: ggml_cast(ctx0, kv_cache_nope, GGML_TYPE_F32);
|
||||
ggml_tensor * kr_f32 = kv_cache_rope_view->type == GGML_TYPE_F32
|
||||
? kv_cache_rope_view
|
||||
: ggml_cast(ctx0, kv_cache_rope_view, GGML_TYPE_F32);
|
||||
kv_cache_nope = ggml_hadamard(ctx0, kn_f32, 64);
|
||||
kv_cache_rope_view = ggml_hadamard(ctx0, kr_f32, 64);
|
||||
}
|
||||
|
||||
// CUDA quantized-cache + REPEAT/CONCAT/CPY has known issues, so force F16 here.
|
||||
const auto kv_type = GGML_TYPE_F16;
|
||||
|
||||
ggml_tensor repeater;
|
||||
repeater.ne[0] = n_embd_head_qk_rope;
|
||||
repeater.ne[1] = n_kv;
|
||||
repeater.ne[2] = n_head_local;
|
||||
repeater.ne[3] = 1;
|
||||
ggml_tensor * k_rope_rep;
|
||||
if (kv_cache_rope_view->type == kv_type) {
|
||||
k_rope_rep = ggml_repeat(ctx0, kv_cache_rope_view, &repeater);
|
||||
} else {
|
||||
auto kv_rope_f16 = ggml_cast(ctx0, kv_cache_rope_view, kv_type);
|
||||
k_rope_rep = ggml_repeat(ctx0, kv_rope_f16, &repeater);
|
||||
}
|
||||
cb(k_rope_rep, "k_rope_rep_pp", il_id);
|
||||
|
||||
// V: wv_b_local viewed as 2D [kv_lora_rank, n_head_local * n_embd_head_v].
|
||||
// Per-rank, no cross-device transfer per call.
|
||||
auto wv_b_2d = ggml_reshape_2d(ctx0, wv_b_local_pp,
|
||||
kv_lora_rank, n_head_local * n_embd_head_v);
|
||||
ggml_tensor * v_2d = ggml_mul_mat(ctx0, wv_b_2d, kv_cache_nope);
|
||||
cb(v_2d, "v_2d_pp", il_id);
|
||||
ggml_tensor * v_f32 = ggml_view_3d(ctx0, v_2d,
|
||||
n_embd_head_v, n_kv, n_head_local,
|
||||
v_2d->nb[1],
|
||||
n_embd_head_v * v_2d->nb[0],
|
||||
0);
|
||||
|
||||
// wk_b_pp is transpose(wk_b) pre-materialized in llm_prepare_mla.
|
||||
// Shape: [kv_lora_rank, n_embd_head_qk_nope, n_head_local].
|
||||
auto wk_b_pp_split = (const ggml_split_tensor_t *)model.layers[il].wk_b_pp->extra;
|
||||
GGML_ASSERT(wk_b_pp_split);
|
||||
ggml_tensor * wk_b_pp_local = wk_b_pp_split->splits[id];
|
||||
GGML_ASSERT(wk_b_pp_local);
|
||||
ggml_tensor * wk_b_T_2d = ggml_reshape_2d(ctx0, wk_b_pp_local,
|
||||
kv_lora_rank, n_head_local * n_embd_head_qk_nope);
|
||||
ggml_tensor * k_nope_2d = ggml_mul_mat(ctx0, wk_b_T_2d, kv_cache_nope);
|
||||
cb(k_nope_2d, "k_nope_2d_pp", il_id);
|
||||
ggml_tensor * k_nope_f32 = ggml_view_3d(ctx0, k_nope_2d,
|
||||
n_embd_head_qk_nope, n_kv, n_head_local,
|
||||
k_nope_2d->nb[1],
|
||||
n_embd_head_qk_nope * k_nope_2d->nb[0],
|
||||
0);
|
||||
|
||||
ggml_tensor * v = ggml_cast(ctx0, v_f32, kv_type);
|
||||
ggml_tensor * k_nope = ggml_cast(ctx0, k_nope_f32, kv_type);
|
||||
ggml_build_forward_expand(gf, v);
|
||||
ggml_build_forward_expand(gf, k_nope);
|
||||
|
||||
ggml_tensor * k = ggml_concat(ctx0, k_rope_rep, k_nope, 0);
|
||||
ggml_build_forward_expand(gf, k);
|
||||
cb(k, "k_full_pp", il_id);
|
||||
|
||||
ggml_tensor * q = ggml_concat(ctx0, q_rope, q_nope, 0);
|
||||
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
|
||||
ggml_build_forward_expand(gf, q);
|
||||
cb(q, "q_concat_pp", il_id);
|
||||
|
||||
ggml_tensor * kqv = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask,
|
||||
kq_scale, hparams.f_max_alibi_bias, 0.f);
|
||||
if (use_f32_attn_precision || q->ne[1] <= 8) {
|
||||
ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
|
||||
}
|
||||
cb(kqv, "kqv_pp", il_id);
|
||||
|
||||
kqv_2d = ggml_reshape_2d(ctx0, kqv, n_embd_head_v * n_head_local, n_tokens);
|
||||
} else {
|
||||
// Absorb path: FlashMLA-3 with the compressed latent cache, then project via wv_b.
|
||||
auto wk_b_split = (const ggml_split_tensor_t *)model.layers[il].wk_b->extra;
|
||||
GGML_ASSERT(wk_b_split);
|
||||
ggml_tensor * wk_b_local = wk_b_split->splits[id];
|
||||
|
||||
ggml_tensor * q_nope_perm = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
|
||||
ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b_local, q_nope_perm);
|
||||
|
||||
ggml_tensor * q_combined = ggml_concat(ctx0,
|
||||
ggml_permute(ctx0, q_rope, 0, 2, 1, 3), q_nope2, 0);
|
||||
if (cparams.k_cache_hadamard) {
|
||||
q_combined = ggml_hadamard(ctx0, q_combined, 64);
|
||||
}
|
||||
|
||||
// FlashMLA-3 path: K = kv_cache (full latent + rope), V = kv_cache_lora (latent only)
|
||||
ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, cache_local,
|
||||
kv_lora_rank, n_kv,
|
||||
row_size_cache,
|
||||
ggml_row_size(cache_local->type, n_embd_head_qk_rope));
|
||||
cb(kv_cache_lora, "kv_cache_lora", il_id);
|
||||
|
||||
ggml_tensor * kqv_compressed = ggml_flash_attn_ext(ctx0,
|
||||
q_combined, kv_cache, kv_cache_lora, KQ_mask,
|
||||
kq_scale, hparams.f_max_alibi_bias, 0.f);
|
||||
cb(kqv_compressed, "kqv_compressed", il_id);
|
||||
if (use_f32_attn_precision) {
|
||||
ggml_flash_attn_ext_set_prec(kqv_compressed, GGML_PREC_F32);
|
||||
}
|
||||
if (cparams.k_cache_hadamard) {
|
||||
kqv_compressed = ggml_hadamard(ctx0, kqv_compressed, 64);
|
||||
}
|
||||
kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
|
||||
|
||||
auto wv_b_split = (const ggml_split_tensor_t *)model.layers[il].wv_b->extra;
|
||||
GGML_ASSERT(wv_b_split);
|
||||
ggml_tensor * wv_b_local = wv_b_split->splits[id];
|
||||
|
||||
ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b_local, kqv_compressed);
|
||||
if (n_tokens > 1) {
|
||||
kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3));
|
||||
}
|
||||
kqv_2d = ggml_reshape_2d(ctx0, kqv, n_embd_head_v * n_head_local, n_tokens);
|
||||
}
|
||||
|
||||
// FlashMLA-3 path: K = kv_cache (full latent + rope), V = kv_cache_lora (latent only)
|
||||
ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, cache_local,
|
||||
kv_lora_rank, n_kv,
|
||||
row_size_cache,
|
||||
ggml_row_size(cache_local->type, n_embd_head_qk_rope));
|
||||
cb(kv_cache_lora, "kv_cache_lora", il_id);
|
||||
|
||||
ggml_tensor * kqv_compressed = ggml_flash_attn_ext(ctx0,
|
||||
q_combined, kv_cache, kv_cache_lora, KQ_mask,
|
||||
kq_scale, hparams.f_max_alibi_bias, 0.f);
|
||||
cb(kqv_compressed, "kqv_compressed", il_id);
|
||||
if (use_f32_attn_precision) {
|
||||
ggml_flash_attn_ext_set_prec(kqv_compressed, GGML_PREC_F32);
|
||||
}
|
||||
if (cparams.k_cache_hadamard) {
|
||||
kqv_compressed = ggml_hadamard(ctx0, kqv_compressed, 64);
|
||||
}
|
||||
kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
|
||||
|
||||
auto wv_b_split = (const ggml_split_tensor_t *)model.layers[il].wv_b->extra;
|
||||
GGML_ASSERT(wv_b_split);
|
||||
ggml_tensor * wv_b_local = wv_b_split->splits[id];
|
||||
|
||||
ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b_local, kqv_compressed);
|
||||
if (n_tokens > 1) {
|
||||
kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3));
|
||||
}
|
||||
ggml_tensor * kqv_2d = ggml_reshape_2d(ctx0, kqv, n_embd_head_v * n_head_local, n_tokens);
|
||||
|
||||
ggml_tensor * partial = llm_build_lora_mm(lctx, ctx0, wo_split->splits[id], kqv_2d);
|
||||
|
||||
// Fold residual into the first non-skipped rank so the reduce result includes it.
|
||||
@ -677,7 +795,7 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
|
||||
|
||||
// whether to use n_tokens as the matrix dimension during multiplication or n_head
|
||||
// n_tokens is higher during prompt processing, this allows to optimize for this case
|
||||
bool pp_opt = n_tokens >= 128; // Is it a fixed constant or is it somehow relared to n_head? original: n_tokens > n_head;
|
||||
bool pp_opt = n_tokens >= 128 && lctx.cparams.mla_attn > 1;
|
||||
|
||||
auto rope_cache = cparams.rope_cache && (rope_type == LLAMA_ROPE_TYPE_NEOX || rope_type == LLAMA_ROPE_TYPE_NORM) ?
|
||||
ggml_rope_cache(ctx0, inp_pos, nullptr, n_rot, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
@ -692,7 +810,7 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
|
||||
if (is_tp_layer) {
|
||||
cur = build_deepseek2_tp_attention(gf, il, inpL, KQ_mask, inp_pos, rope_cache,
|
||||
kq_scale, attn_factor_scaled,
|
||||
use_f32_attn_precision, is_lite);
|
||||
use_f32_attn_precision, is_lite, pp_opt);
|
||||
} else {
|
||||
cur = build_deepseek2_layer_attention(gf, il, inpL, KQ_mask, inp_pos, rope_cache,
|
||||
kq_scale, attn_factor_scaled,
|
||||
|
||||
@ -265,7 +265,8 @@ struct llm_build_context {
|
||||
ggml_tensor * rope_cache,
|
||||
float kq_scale, float attn_factor_scaled,
|
||||
bool use_f32_attn_precision,
|
||||
bool is_lite);
|
||||
bool is_lite,
|
||||
bool pp_opt);
|
||||
|
||||
ggml_tensor * build_deepseek2_layer_attention(
|
||||
ggml_cgraph * gf, int il,
|
||||
|
||||
@ -177,6 +177,9 @@ struct llama_layer {
|
||||
struct ggml_tensor * wkq_a_mqa = nullptr;
|
||||
struct ggml_tensor * wkv_b = nullptr;
|
||||
struct ggml_tensor * wk_b = nullptr;
|
||||
// wk_b in pp_opt-favoring layout [kv_lora_rank, qk_nope, n_head], serialized
|
||||
// as "attn_kv_b.weight". Materialized under -sm graph/attn: always when wk_b is synthesized from wkv_b, and only for mla>1 when the GGUF ships wk_b directly.
|
||||
struct ggml_tensor * wk_b_pp = nullptr;
|
||||
struct ggml_tensor * wv_b = nullptr;
|
||||
struct ggml_tensor * wq_cross = nullptr;
|
||||
struct ggml_tensor * wk_cross = nullptr;
|
||||
@ -224,6 +227,7 @@ struct llama_layer {
|
||||
llama_split_tensor split_wq_b;
|
||||
llama_split_tensor split_wkv_a_mqa;
|
||||
llama_split_tensor split_wk_b;
|
||||
llama_split_tensor split_wk_b_pp;
|
||||
llama_split_tensor split_wv_b;
|
||||
llama_split_tensor split_attn_q_a_norm;
|
||||
llama_split_tensor split_attn_kv_a_norm;
|
||||
@ -381,11 +385,13 @@ struct llama_layer {
|
||||
struct llama_layer_nextn nextn;
|
||||
|
||||
std::unique_ptr<ggml_tensor> computed_wk_b;
|
||||
std::unique_ptr<ggml_tensor> computed_wk_b_pp;
|
||||
std::unique_ptr<ggml_tensor> computed_wv_b;
|
||||
std::unique_ptr<ggml_tensor> computed_wkv_b;
|
||||
|
||||
// Per-device replicas of computed wk_b/wv_b (-sm graph). Buffers owned via model.bufs.
|
||||
std::vector<std::unique_ptr<ggml_tensor>> computed_wk_b_replicas;
|
||||
std::vector<std::unique_ptr<ggml_tensor>> computed_wk_b_pp_replicas;
|
||||
std::vector<std::unique_ptr<ggml_tensor>> computed_wv_b_replicas;
|
||||
};
|
||||
|
||||
|
||||
229
src/llama.cpp
229
src/llama.cpp
@ -859,7 +859,8 @@ static bool llama_kv_cache_init(
|
||||
if (is_mla_attn) {
|
||||
bool have_wkv_b = true;
|
||||
for (auto& l : model.layers) {
|
||||
if (!l.wkv_b) {
|
||||
// Under -sm graph mla>1, wk_b_pp (attn_kv_b) substitutes for wkv_b.
|
||||
if (!l.wkv_b && !l.wk_b_pp) {
|
||||
have_wkv_b = false;
|
||||
break;
|
||||
}
|
||||
@ -2285,7 +2286,10 @@ static void llm_prepare_mla(llama_model & model, int mla) {
|
||||
}
|
||||
}
|
||||
}
|
||||
auto context_size = max_wk_size + 2*n_embd_head_qk_nope*kv_lora_rank*n_head*sizeof(float);
|
||||
// tensor_data layout: [wk_b_f32 | wk_b_f32_t | wk_b | wk_b_pp]. wk_b_pp slot is only
|
||||
// populated under -sm graph/attn (pp_opt-favoring layout); allocated unconditionally
|
||||
// for simplicity (one max_wk_size buffer).
|
||||
auto context_size = 2*max_wk_size + 2*n_embd_head_qk_nope*kv_lora_rank*n_head*sizeof(float);
|
||||
context_size *= 2; // just in case;
|
||||
std::vector<uint8_t> wkv_buffer;
|
||||
if (max_wkv_size > 0) wkv_buffer.resize(max_wkv_size);
|
||||
@ -2296,7 +2300,7 @@ static void llm_prepare_mla(llama_model & model, int mla) {
|
||||
ggml_init_params params{context_size, nullptr, true};
|
||||
auto ctx = ggml_init(params);
|
||||
auto graph = ggml_new_graph_custom(ctx, 8, false);
|
||||
std::vector<uint8_t> tensor_data(2*n_embd_head_qk_nope*kv_lora_rank*n_head*sizeof(float) + max_wk_size);
|
||||
std::vector<uint8_t> tensor_data(2*n_embd_head_qk_nope*kv_lora_rank*n_head*sizeof(float) + 2*max_wk_size);
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
auto& l = model.layers[il];
|
||||
if (l.wk_b) continue;
|
||||
@ -2375,8 +2379,15 @@ static void llm_prepare_mla(llama_model & model, int mla) {
|
||||
const size_t slice_bytes = (size_t)n_head_local * head_block_bytes;
|
||||
auto dev_buft = ggml_backend_buffer_get_type(wo_split->splits[id]->buffer);
|
||||
auto dev_buf = ggml_backend_buft_alloc_buffer(dev_buft, slice_bytes);
|
||||
if (!dev_buf) {
|
||||
throw std::runtime_error("Failed to allocate per-rank buffer for " + tname);
|
||||
}
|
||||
ggml_backend_buffer_set_usage(dev_buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
|
||||
model.bufs.push_back(dev_buf);
|
||||
// Intentionally not updating mem_used[id] here: llm_prepare_mla
|
||||
// runs post-load, after distribute_mla_tensors has completed its
|
||||
// allocation rounds, so no downstream allocator consults mem_used
|
||||
// anymore.
|
||||
|
||||
replicas[id] = std::make_unique<ggml_tensor>(*source);
|
||||
auto rep = replicas[id].get();
|
||||
@ -2429,6 +2440,32 @@ static void llm_prepare_mla(llama_model & model, int mla) {
|
||||
return computed.get();
|
||||
};
|
||||
|
||||
// pp_opt-favoring wk_b_pp = quantize(wk_b_f32) directly (absorb-favoring
|
||||
// wk_b above is its transpose). Not gated on mla — wk_b_pp shares wk_b_f32
|
||||
// with the wk_b synthesis above and skipping it breaks absorb on some
|
||||
// quant combinations. Second pass below gates the mla=1 saving for
|
||||
// GGUFs that ship wk_b directly.
|
||||
if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) {
|
||||
ggml_graph_clear(graph);
|
||||
auto wk_b_pp = ggml_cast(ctx, wk_b_f32, new_type);
|
||||
wk_b_pp->data = (char *)wk_b->data + ggml_nbytes(wk_b);
|
||||
// tensor_data layout is [wk_b_f32 | wk_b_f32_t | wk_b | wk_b_pp]. wk_b and
|
||||
// wk_b_pp are the same quant type and same shape, so wk_b_pp's slot is
|
||||
// exactly one wk_b-sized block past wk_b. The tensor_data vector was sized
|
||||
// 2*F32 + 2*max_wk_size, which covers both wk_b and wk_b_pp.
|
||||
GGML_ASSERT((char *)wk_b_pp->data + ggml_nbytes(wk_b_pp) <=
|
||||
(char *)tensor_data.data() + tensor_data.size());
|
||||
ggml_build_forward_expand(graph, wk_b_pp);
|
||||
auto plan_pp = ggml_graph_plan(graph, std::thread::hardware_concurrency()/2);
|
||||
if (plan_pp.work_size > work_data.size()) work_data.resize(plan_pp.work_size);
|
||||
plan_pp.work_data = work_data.data();
|
||||
auto status_pp = ggml_graph_compute(graph, &plan_pp);
|
||||
if (status_pp != GGML_STATUS_SUCCESS) throw std::runtime_error("Failed to compute attn_kv_b");
|
||||
auto name_pp = std::string{"blk."} + std::to_string(il) + ".attn_kv_b.weight";
|
||||
l.wk_b_pp = materialize(wk_b_pp, l.computed_wk_b_pp, l.computed_wk_b_pp_replicas, l.split_wk_b_pp, name_pp);
|
||||
ggml_graph_clear(graph);
|
||||
}
|
||||
|
||||
l.wk_b = materialize(wk_b, l.computed_wk_b, l.computed_wk_b_replicas, l.split_wk_b, name);
|
||||
|
||||
ggml_graph_clear(graph);
|
||||
@ -2450,6 +2487,192 @@ static void llm_prepare_mla(llama_model & model, int mla) {
|
||||
}
|
||||
ggml_free(ctx);
|
||||
}
|
||||
|
||||
// Second pass: for layers where wk_b came from the GGUF directly, produce
|
||||
// wk_b_pp here. Only under -sm graph/attn AND mla > 1; mla=1 skips pp_opt.
|
||||
if ((model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) && mla > 1) {
|
||||
int n_pp_to_compute = 0;
|
||||
for (auto & l : model.layers) {
|
||||
if (l.wk_b && !l.wk_b_pp) ++n_pp_to_compute;
|
||||
}
|
||||
if (n_pp_to_compute > 0) {
|
||||
const uint32_t n_embd_head_qk_nope_pp = hparams.n_embd_head_k(0) - hparams.n_rot;
|
||||
const uint32_t kv_lora_rank_pp = hparams.n_lora_kv;
|
||||
const int32_t n_head_pp = hparams.n_head(0);
|
||||
|
||||
size_t max_wk_size_pp = 0;
|
||||
for (auto & l : model.layers) {
|
||||
if (l.wk_b && !l.wk_b_pp) {
|
||||
max_wk_size_pp = std::max(max_wk_size_pp, ggml_nbytes(l.wk_b));
|
||||
}
|
||||
}
|
||||
auto context_size_pp = 4 * max_wk_size_pp +
|
||||
4 * (size_t)n_embd_head_qk_nope_pp * kv_lora_rank_pp * n_head_pp * sizeof(float);
|
||||
context_size_pp *= 2;
|
||||
|
||||
std::vector<uint8_t> work_data_pp;
|
||||
// Hoist the full-wk_b assembly buffer outside the per-layer loop so we
|
||||
// don't re-allocate ~max_wk_size_pp bytes per layer.
|
||||
std::vector<uint8_t> full_wk_b_host_buf(max_wk_size_pp);
|
||||
// tensor_data_pp holds, in order: [wk_b_f32 | wk_b_pp_f32 | wk_b_pp_q].
|
||||
// F32 slots size by n_embd_head_qk_nope * kv_lora_rank * n_head * 4; the
|
||||
// requantized slot fits in max_wk_size_pp. After the N->1 graph compute
|
||||
// consolidation we no longer need a second wk_b-sized slot.
|
||||
std::vector<uint8_t> tensor_data_pp(
|
||||
2 * (size_t)n_embd_head_qk_nope_pp * kv_lora_rank_pp * n_head_pp * sizeof(float) +
|
||||
max_wk_size_pp);
|
||||
|
||||
ggml_init_params params_pp{context_size_pp, nullptr, true};
|
||||
auto ctx_pp = ggml_init(params_pp);
|
||||
auto graph_pp = ggml_new_graph_custom(ctx_pp, 8, false);
|
||||
LLAMA_LOG_INFO("============ %s: need to compute %d attn_kv_b tensors\n", __func__, n_pp_to_compute);
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
auto & l = model.layers[il];
|
||||
if (!l.wk_b || l.wk_b_pp) continue;
|
||||
|
||||
// Under -sm graph/attn (the outer block's gate), distribute_mla_tensors
|
||||
// always populates l.wk_b->extra and l.wo->extra. If either is missing,
|
||||
// we're in a degenerate config: wk_b_pp stays null, so the TP attention
|
||||
// builder (which requires wk_b_pp for pp_opt) keeps the absorb path; skip here.
|
||||
if (!l.wo || !l.wo->extra || !l.wk_b->extra) continue;
|
||||
|
||||
// Per-rank wk_b slices: each lives on a single device as a regular CUDA
|
||||
// tensor (not the split-buffer wrapper which lacks a get_tensor impl for
|
||||
// split_dim=2). Read each rank's slice independently.
|
||||
auto wk_b_split = (const ggml_split_tensor_t *)l.wk_b->extra;
|
||||
auto wo_split = (const ggml_split_tensor_t *)l.wo->extra;
|
||||
const int n_device = wo_split->n_device;
|
||||
|
||||
auto name = std::string{"blk."} + std::to_string(il) + ".attn_kv_b.weight";
|
||||
|
||||
// Build a placeholder wk_b_pp_q on host with the full [kv_lora_rank, qk_nope, n_head]
|
||||
// shape (only used as a template for cloning per-rank metadata; no data filled).
|
||||
ggml_tensor wk_b_pp_template = *l.wk_b;
|
||||
wk_b_pp_template.ne[0] = (int64_t)kv_lora_rank_pp;
|
||||
wk_b_pp_template.ne[1] = (int64_t)n_embd_head_qk_nope_pp;
|
||||
wk_b_pp_template.ne[2] = (int64_t)n_head_pp;
|
||||
wk_b_pp_template.nb[0] = ggml_type_size(l.wk_b->type);
|
||||
wk_b_pp_template.nb[1] = wk_b_pp_template.nb[0] * (wk_b_pp_template.ne[0] / ggml_blck_size(l.wk_b->type));
|
||||
wk_b_pp_template.nb[2] = wk_b_pp_template.nb[1] * wk_b_pp_template.ne[1];
|
||||
wk_b_pp_template.nb[3] = wk_b_pp_template.nb[2] * wk_b_pp_template.ne[2];
|
||||
|
||||
l.computed_wk_b_pp = std::make_unique<ggml_tensor>(wk_b_pp_template);
|
||||
l.computed_wk_b_pp->buffer = nullptr;
|
||||
l.computed_wk_b_pp->data = nullptr;
|
||||
l.computed_wk_b_pp->op = GGML_OP_NONE;
|
||||
for (int j = 0; j < GGML_MAX_SRC; ++j) l.computed_wk_b_pp->src[j] = nullptr;
|
||||
ggml_set_name(l.computed_wk_b_pp.get(), name.c_str());
|
||||
|
||||
l.computed_wk_b_pp_replicas.resize(n_device);
|
||||
l.split_wk_b_pp.tensor_splits.assign(n_device, nullptr);
|
||||
const size_t per_head_pp_bytes = wk_b_pp_template.nb[2];
|
||||
|
||||
// Build head_offsets[] from per-rank ne[2]; matches the layout used by
|
||||
// prepare_split_tensors(split_dim=2) on wk_b so that head_offsets[id]
|
||||
// points to the start of rank id's head range in the full wk_b layout.
|
||||
std::vector<int> head_offsets(n_device + 1, 0);
|
||||
for (int id = 0; id < n_device; ++id) {
|
||||
int n_h_id = 0;
|
||||
if (wk_b_split->splits[id]) {
|
||||
n_h_id = (int)wk_b_split->splits[id]->ne[2];
|
||||
}
|
||||
head_offsets[id + 1] = head_offsets[id] + n_h_id;
|
||||
}
|
||||
|
||||
// Read all per-rank wk_b slices into the hoisted host buffer ordered
|
||||
// by head_offset. The data layout on disk is per-head contiguous, so
|
||||
// sequential rank reads at byte offset head_offsets[id] * per_head_in_bytes
|
||||
// reconstitute the original full wk_b on host.
|
||||
const size_t per_head_in_bytes = l.wk_b->nb[2];
|
||||
const size_t full_wk_b_nbytes = (size_t)head_offsets[n_device] * per_head_in_bytes;
|
||||
GGML_ASSERT(full_wk_b_nbytes <= full_wk_b_host_buf.size());
|
||||
for (int id = 0; id < n_device; ++id) {
|
||||
if (!wk_b_split->splits[id]) continue;
|
||||
auto wk_b_rank = wk_b_split->splits[id];
|
||||
const size_t rank_nbytes = ggml_nbytes(wk_b_rank);
|
||||
const size_t byte_offset = (size_t)head_offsets[id] * per_head_in_bytes;
|
||||
GGML_ASSERT(byte_offset + rank_nbytes <= full_wk_b_nbytes);
|
||||
ggml_backend_tensor_get(wk_b_rank, full_wk_b_host_buf.data() + byte_offset, 0, rank_nbytes);
|
||||
}
|
||||
|
||||
// ONE graph compute on the full assembled wk_b: dequant -> transpose ->
|
||||
// requant. Per-rank slicing happens on the resulting host buffer.
|
||||
ggml_tensor full_host = *l.wk_b;
|
||||
full_host.data = full_wk_b_host_buf.data();
|
||||
auto f_f32 = ggml_cast(ctx_pp, &full_host, GGML_TYPE_F32);
|
||||
f_f32->data = tensor_data_pp.data();
|
||||
auto f_view = ggml_transpose(ctx_pp, f_f32);
|
||||
auto f_cont = ggml_cont(ctx_pp, f_view);
|
||||
f_cont->data = (char *)f_f32->data + ggml_nbytes(f_f32);
|
||||
auto f_q = ggml_cast(ctx_pp, f_cont, l.wk_b->type);
|
||||
f_q->data = (char *)f_cont->data + ggml_nbytes(f_cont);
|
||||
ggml_build_forward_expand(graph_pp, f_q);
|
||||
auto plan = ggml_graph_plan(graph_pp, std::thread::hardware_concurrency()/2);
|
||||
if (plan.work_size > work_data_pp.size()) work_data_pp.resize(plan.work_size);
|
||||
plan.work_data = work_data_pp.data();
|
||||
auto status = ggml_graph_compute(graph_pp, &plan);
|
||||
if (status != GGML_STATUS_SUCCESS) throw std::runtime_error("Failed to compute attn_kv_b");
|
||||
ggml_graph_clear(graph_pp);
|
||||
|
||||
// Per-rank upload: f_q is the full wk_b_pp in [kv_lora_rank, qk_nope,
|
||||
// n_head] order. Per-head stride matches per_head_pp_bytes by construction.
|
||||
for (int id = 0; id < n_device; ++id) {
|
||||
if (!wo_split->splits[id] || !wo_split->splits[id]->buffer) continue;
|
||||
if (!wk_b_split->splits[id]) continue;
|
||||
const int n_head_local = (int)wk_b_split->splits[id]->ne[2];
|
||||
if (n_head_local <= 0) continue;
|
||||
|
||||
const size_t slice_bytes = (size_t)n_head_local * per_head_pp_bytes;
|
||||
auto dev_buft = ggml_backend_buffer_get_type(wo_split->splits[id]->buffer);
|
||||
auto dev_buf = ggml_backend_buft_alloc_buffer(dev_buft, slice_bytes);
|
||||
if (!dev_buf) {
|
||||
throw std::runtime_error("Failed to allocate per-rank buffer for " + name);
|
||||
}
|
||||
ggml_backend_buffer_set_usage(dev_buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
|
||||
model.bufs.push_back(dev_buf);
|
||||
// Same mem_used note as Path A: distribute_mla_tensors has already
|
||||
// closed its books before llm_prepare_mla runs.
|
||||
|
||||
l.computed_wk_b_pp_replicas[id] = std::make_unique<ggml_tensor>(wk_b_pp_template);
|
||||
auto rep = l.computed_wk_b_pp_replicas[id].get();
|
||||
rep->ne[2] = n_head_local;
|
||||
rep->nb[3] = rep->nb[2] * (size_t)rep->ne[2];
|
||||
rep->buffer = dev_buf;
|
||||
rep->data = ggml_backend_buffer_get_base(dev_buf);
|
||||
rep->op = GGML_OP_NONE;
|
||||
for (int j = 0; j < GGML_MAX_SRC; ++j) rep->src[j] = nullptr;
|
||||
rep->view_src = nullptr;
|
||||
rep->view_offs = 0;
|
||||
rep->extra = nullptr;
|
||||
ggml_set_name(rep, (name + "." + std::to_string(id)).c_str());
|
||||
|
||||
const size_t byte_offset = (size_t)head_offsets[id] * per_head_pp_bytes;
|
||||
ggml_backend_tensor_set(rep, (char *)f_q->data + byte_offset, 0, slice_bytes);
|
||||
if (ggml_backend_buffer_is_host(rep->buffer)) {
|
||||
iqk_modify_tensor(rep);
|
||||
}
|
||||
l.split_wk_b_pp.tensor_splits[id] = rep;
|
||||
}
|
||||
|
||||
l.split_wk_b_pp.ggml.n_device = n_device;
|
||||
l.split_wk_b_pp.ggml.split_dim = 2;
|
||||
l.split_wk_b_pp.ggml.splits = l.split_wk_b_pp.tensor_splits.data();
|
||||
l.computed_wk_b_pp->extra = (void *)&l.split_wk_b_pp.ggml;
|
||||
model.tensors_by_name.push_back(std::make_pair(name, l.computed_wk_b_pp.get()));
|
||||
l.wk_b_pp = l.computed_wk_b_pp.get();
|
||||
|
||||
printf("Computed %s as %d x %d x %d of type %s, split across %d devices on dim=2\n",
|
||||
name.c_str(),
|
||||
(int)l.computed_wk_b_pp->ne[0],
|
||||
(int)l.computed_wk_b_pp->ne[1],
|
||||
(int)l.computed_wk_b_pp->ne[2],
|
||||
ggml_type_name(l.computed_wk_b_pp->type), n_device);
|
||||
}
|
||||
ggml_free(ctx_pp);
|
||||
}
|
||||
}
|
||||
|
||||
if (mla == 1 || model.split_mode == LLAMA_SPLIT_MODE_GRAPH) return;
|
||||
|
||||
n_to_compute = 0;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user