Split mode graph for Mellum (#1920)

This commit is contained in:
Kawrakow 2026-06-04 15:20:41 +02:00 committed by GitHub
parent dc51c6f9b2
commit 4406e637b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 17 additions and 41 deletions

View File

@ -21,47 +21,13 @@ ggml_cgraph * llm_build_context::build_mellum() {
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k(il));
GGML_ASSERT(n_embd_head == hparams.n_rot);
struct ggml_tensor * inpSA = inpL;
auto KQ_mask_l = is_swa ? KQ_mask_swa : KQ_mask;
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
cur = build_std_attention(gf, model.layers[il].attn_norm, inpL,
inp_pos, il == n_layer - 1 ? inp_out_ids : nullptr, nullptr, KQ_mask_l,
nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), 0.0f, is_swa ? hparams.n_swa : 0, il, true, false, true);
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, nullptr,
model.layers[il].wqk, nullptr,
model.layers[il].wq, nullptr,
model.layers[il].wk, nullptr,
model.layers[il].wv, nullptr,
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0.0f, il);
const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : freq_base;
const float freq_scale_l = is_swa ? 1.0f : freq_scale;
const float ext_factor_l = is_swa ? 0.0f : ext_factor;
const float attn_factor_l = is_swa ? 1.0f : attn_factor;
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor_l, attn_factor_l, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor_l, attn_factor_l, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, is_swa ? KQ_mask_swa : KQ_mask,
n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il, nullptr, is_swa ? hparams.n_swa : 0);
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp,
cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, cur,
model.layers[il].ffn_gate_inp, nullptr,
model.layers[il].ffn_up_exps, nullptr,
model.layers[il].ffn_gate_exps, nullptr,

View File

@ -474,7 +474,7 @@ void llm_load_hparams(
if (hparams.n_swa > 0) {
hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train;
hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
hparams.rope_freq_scale_train_swa = 1; //hparams.rope_freq_scale_train;
if (!ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer, false)) {
for (uint32_t i = 0; i < hparams.n_layer; ++i) {

View File

@ -3077,6 +3077,7 @@ static bool is_model_split_supported(const llama_model & model) {
LLM_ARCH_DEEPSEEK2,
LLM_ARCH_GLM_DSA,
LLM_ARCH_MISTRAL4,
LLM_ARCH_MELLUM,
};
auto it = k_supported.find(model.arch);
return it != k_supported.end();
@ -6847,11 +6848,20 @@ struct llama_context * llama_init_from_model(
if (cparams.reduce_type == GGML_TYPE_F16) {
LLAMA_LOG_WARN("=====================================================================\n");
LLAMA_LOG_WARN("GPT-OSS with split mode graph requires f32 precision\n");
LLAMA_LOG_WARN(" => changing cparams.split_mode_f16 to 'false'\n");
LLAMA_LOG_WARN(" => changing cparams.reduce_type to f32\n");
LLAMA_LOG_WARN("=====================================================================\n");
cparams.reduce_type = GGML_TYPE_F32;
}
}
if (model->arch == LLM_ARCH_MELLUM && model->split_mode == LLAMA_SPLIT_MODE_GRAPH) {
if (cparams.reduce_type == GGML_TYPE_F16) {
LLAMA_LOG_WARN("=====================================================================\n");
LLAMA_LOG_WARN("Mellum with split mode graph requires bf16 or f32 precision\n");
LLAMA_LOG_WARN(" => changing cparams.reduce_type to bf16\n");
LLAMA_LOG_WARN("=====================================================================\n");
cparams.reduce_type = GGML_TYPE_BF16;
}
}
if (model->arch != LLM_ARCH_GLM4_MOE && model->arch != LLM_ARCH_QWEN35 &&
model->arch != LLM_ARCH_QWEN35MOE && model->arch != LLM_ARCH_GEMMA4 &&