mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Split mode graph for Mellum (#1920)
This commit is contained in:
parent
dc51c6f9b2
commit
4406e637b5
@ -21,47 +21,13 @@ ggml_cgraph * llm_build_context::build_mellum() {
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k(il));
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
auto KQ_mask_l = is_swa ? KQ_mask_swa : KQ_mask;
|
||||
|
||||
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
cur = build_std_attention(gf, model.layers[il].attn_norm, inpL,
|
||||
inp_pos, il == n_layer - 1 ? inp_out_ids : nullptr, nullptr, KQ_mask_l,
|
||||
nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), 0.0f, is_swa ? hparams.n_swa : 0, il, true, false, true);
|
||||
|
||||
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
|
||||
model.layers[il].wqkv, nullptr,
|
||||
model.layers[il].wqk, nullptr,
|
||||
model.layers[il].wq, nullptr,
|
||||
model.layers[il].wk, nullptr,
|
||||
model.layers[il].wv, nullptr,
|
||||
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0.0f, il);
|
||||
|
||||
const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : freq_base;
|
||||
const float freq_scale_l = is_swa ? 1.0f : freq_scale;
|
||||
const float ext_factor_l = is_swa ? 0.0f : ext_factor;
|
||||
const float attn_factor_l = is_swa ? 1.0f : attn_factor;
|
||||
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
|
||||
ext_factor_l, attn_factor_l, beta_fast, beta_slow);
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
|
||||
ext_factor_l, attn_factor_l, beta_fast, beta_slow);
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Kcur, Vcur, Qcur, is_swa ? KQ_mask_swa : KQ_mask,
|
||||
n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il, nullptr, is_swa ? hparams.n_swa : 0);
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp,
|
||||
cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, cur,
|
||||
model.layers[il].ffn_gate_inp, nullptr,
|
||||
model.layers[il].ffn_up_exps, nullptr,
|
||||
model.layers[il].ffn_gate_exps, nullptr,
|
||||
|
||||
@ -474,7 +474,7 @@ void llm_load_hparams(
|
||||
|
||||
if (hparams.n_swa > 0) {
|
||||
hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train;
|
||||
hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
|
||||
hparams.rope_freq_scale_train_swa = 1; //hparams.rope_freq_scale_train;
|
||||
|
||||
if (!ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer, false)) {
|
||||
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
|
||||
|
||||
@ -3077,6 +3077,7 @@ static bool is_model_split_supported(const llama_model & model) {
|
||||
LLM_ARCH_DEEPSEEK2,
|
||||
LLM_ARCH_GLM_DSA,
|
||||
LLM_ARCH_MISTRAL4,
|
||||
LLM_ARCH_MELLUM,
|
||||
};
|
||||
auto it = k_supported.find(model.arch);
|
||||
return it != k_supported.end();
|
||||
@ -6847,11 +6848,20 @@ struct llama_context * llama_init_from_model(
|
||||
if (cparams.reduce_type == GGML_TYPE_F16) {
|
||||
LLAMA_LOG_WARN("=====================================================================\n");
|
||||
LLAMA_LOG_WARN("GPT-OSS with split mode graph requires f32 precision\n");
|
||||
LLAMA_LOG_WARN(" => changing cparams.split_mode_f16 to 'false'\n");
|
||||
LLAMA_LOG_WARN(" => changing cparams.reduce_type to f32\n");
|
||||
LLAMA_LOG_WARN("=====================================================================\n");
|
||||
cparams.reduce_type = GGML_TYPE_F32;
|
||||
}
|
||||
}
|
||||
if (model->arch == LLM_ARCH_MELLUM && model->split_mode == LLAMA_SPLIT_MODE_GRAPH) {
|
||||
if (cparams.reduce_type == GGML_TYPE_F16) {
|
||||
LLAMA_LOG_WARN("=====================================================================\n");
|
||||
LLAMA_LOG_WARN("Mellum with split mode graph requires bf16 or f32 precision\n");
|
||||
LLAMA_LOG_WARN(" => changing cparams.reduce_type to bf16\n");
|
||||
LLAMA_LOG_WARN("=====================================================================\n");
|
||||
cparams.reduce_type = GGML_TYPE_BF16;
|
||||
}
|
||||
}
|
||||
|
||||
if (model->arch != LLM_ARCH_GLM4_MOE && model->arch != LLM_ARCH_QWEN35 &&
|
||||
model->arch != LLM_ARCH_QWEN35MOE && model->arch != LLM_ARCH_GEMMA4 &&
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user