diff --git a/src/graphs/build_mellum.cpp b/src/graphs/build_mellum.cpp index 8b62ed9d..eb84fdad 100644 --- a/src/graphs/build_mellum.cpp +++ b/src/graphs/build_mellum.cpp @@ -21,47 +21,13 @@ ggml_cgraph * llm_build_context::build_mellum() { GGML_ASSERT(n_embd_head == hparams.n_embd_head_k(il)); GGML_ASSERT(n_embd_head == hparams.n_rot); - struct ggml_tensor * inpSA = inpL; + auto KQ_mask_l = is_swa ? KQ_mask_swa : KQ_mask; - cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il); - cb(cur, "attn_norm", il); + cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, + inp_pos, il == n_layer - 1 ? inp_out_ids : nullptr, nullptr, KQ_mask_l, + nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), 0.0f, is_swa ? hparams.n_swa : 0, il, true, false, true); - auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, - model.layers[il].wqkv, nullptr, - model.layers[il].wqk, nullptr, - model.layers[il].wq, nullptr, - model.layers[il].wk, nullptr, - model.layers[il].wv, nullptr, - model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0.0f, il); - - const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : freq_base; - const float freq_scale_l = is_swa ? 1.0f : freq_scale; - const float ext_factor_l = is_swa ? 0.0f : ext_factor; - const float attn_factor_l = is_swa ? 1.0f : attn_factor; - - Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, - ext_factor_l, attn_factor_l, beta_fast, beta_slow); - Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, - ext_factor_l, attn_factor_l, beta_fast, beta_slow); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, is_swa ? KQ_mask_swa : KQ_mask, - n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il, nullptr, is_swa ? hparams.n_swa : 0); - - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, + cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, cur, model.layers[il].ffn_gate_inp, nullptr, model.layers[il].ffn_up_exps, nullptr, model.layers[il].ffn_gate_exps, nullptr, diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 687df666..448774ae 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -474,7 +474,7 @@ void llm_load_hparams( if (hparams.n_swa > 0) { hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + hparams.rope_freq_scale_train_swa = 1; //hparams.rope_freq_scale_train; if (!ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer, false)) { for (uint32_t i = 0; i < hparams.n_layer; ++i) { diff --git a/src/llama.cpp b/src/llama.cpp index a5d05e60..73417724 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3077,6 +3077,7 @@ static bool is_model_split_supported(const llama_model & model) { LLM_ARCH_DEEPSEEK2, LLM_ARCH_GLM_DSA, LLM_ARCH_MISTRAL4, + LLM_ARCH_MELLUM, }; auto it = k_supported.find(model.arch); return it != k_supported.end(); @@ -6847,11 +6848,20 @@ struct llama_context * llama_init_from_model( if (cparams.reduce_type == GGML_TYPE_F16) { LLAMA_LOG_WARN("=====================================================================\n"); LLAMA_LOG_WARN("GPT-OSS with split mode graph requires f32 precision\n"); - LLAMA_LOG_WARN(" => changing cparams.split_mode_f16 to 'false'\n"); + LLAMA_LOG_WARN(" => changing cparams.reduce_type to f32\n"); LLAMA_LOG_WARN("=====================================================================\n"); cparams.reduce_type = GGML_TYPE_F32; } } + if (model->arch == LLM_ARCH_MELLUM && model->split_mode == LLAMA_SPLIT_MODE_GRAPH) { + if (cparams.reduce_type == GGML_TYPE_F16) { + LLAMA_LOG_WARN("=====================================================================\n"); + LLAMA_LOG_WARN("Mellum with split mode graph requires bf16 or f32 precision\n"); + LLAMA_LOG_WARN(" => changing cparams.reduce_type to bf16\n"); + LLAMA_LOG_WARN("=====================================================================\n"); + cparams.reduce_type = GGML_TYPE_BF16; + } + } if (model->arch != LLM_ARCH_GLM4_MOE && model->arch != LLM_ARCH_QWEN35 && model->arch != LLM_ARCH_QWEN35MOE && model->arch != LLM_ARCH_GEMMA4 &&