Split mode graph for Laguna (#1939)

This commit is contained in:
Kawrakow 2026-06-09 10:13:30 +02:00 committed by GitHub
parent 11c3546235
commit 2768b62515
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 48 additions and 105 deletions

View File

@ -7,7 +7,7 @@ ggml_cgraph * llm_build_context::build_laguna() {
ggml_tensor * inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = build_inp_out_ids();
ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr;
ggml_tensor * KQ_mask = build_inp_KQ_mask();
ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
@ -15,105 +15,22 @@ ggml_cgraph * llm_build_context::build_laguna() {
const bool is_swa = hparams.swa_layers[il];
const int n_swa_l = is_swa ? hparams.n_swa : 0;
ggml_tensor * inpSA = inpL;
auto KQ_mask_l = is_swa ? KQ_mask_swa : KQ_mask;
auto rope_factors = is_swa ? nullptr : build_rope_factors(il);
ggml_tensor * cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
ggml_tensor * input_normed = cur;
ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
ggml_build_forward_expand(gf, Qcur);
ggml_build_forward_expand(gf, Kcur);
ggml_build_forward_expand(gf, Vcur);
const int64_t n_head_l = hparams.n_head(il);
const int64_t n_head_kv_l = hparams.n_head_kv(il);
const int64_t n_embd_head_k = hparams.n_embd_head_k(il);
const int64_t n_embd_head_v = hparams.n_embd_head_v(il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens);
if (model.layers[il].attn_q_norm) {
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur_normed", il);
ggml_build_forward_expand(gf, Qcur);
}
if (model.layers[il].attn_k_norm) {
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(Kcur, "Kcur_normed", il);
ggml_build_forward_expand(gf, Kcur);
}
const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
const float ext_factor_l = is_swa ? 0.0f : ext_factor;
const float attn_factor_l = is_swa ? 1.0f : attn_factor;
const float beta_fast_l = is_swa ? 32.0f : beta_fast;
const float beta_slow_l = is_swa ? 1.0f : beta_slow;
ggml_tensor * rope_factors = is_swa ? nullptr : build_rope_factors(il);
const int n_rot_l = hparams.rope_n_rot(il);
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors,
n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor_l, attn_factor_l, beta_fast_l, beta_slow_l);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors,
n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor_l, attn_factor_l, beta_fast_l, beta_slow_l);
cb(Qcur, "Qcur_roped", il);
cb(Kcur, "Kcur_roped", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
nullptr, nullptr,
Kcur, Vcur, Qcur,
is_swa ? KQ_mask_swa : KQ_mask,
n_tokens, kv_head, n_kv,
1.0f / sqrtf(float(n_embd_head_k)), cb, il, nullptr, n_swa_l);
cb(cur, "attn_out", il);
if (model.layers[il].wqkv_gate) {
ggml_tensor * gate = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv_gate, input_normed);
cb(gate, "attn_gate", il);
gate = ggml_softplus(ctx0, gate);
cb(gate, "attn_gate_softplus", il);
ggml_tensor * attn_3d = ggml_reshape_3d(ctx0, cur, n_embd_head_v, n_head_l, n_tokens);
ggml_tensor * gate_3d = ggml_reshape_3d(ctx0, gate, 1, n_head_l, n_tokens);
cb(gate_3d, "attn_gate_3d", il);
cur = ggml_mul(ctx0, attn_3d, gate_3d);
cb(cur, "attn_gated_3d", il);
cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v * n_head_l, n_tokens);
cb(cur, "attn_gated", il);
}
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
cb(cur, "attn_proj", il);
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
auto cur = build_std_attention(gf, model.layers[il].attn_norm, inpL,
inp_pos, il == n_layer - 1 ? inp_out_ids : nullptr, rope_factors,
KQ_mask_l, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head_k)), 0.0f, n_swa_l, il, true, false, true);
if (model.layers[il].ffn_gate_inp == nullptr) {
cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp,
cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, cur, //ffn_inp,
model.layers[il].ffn_up, nullptr, nullptr,
model.layers[il].ffn_gate, nullptr, nullptr,
model.layers[il].ffn_down, nullptr, nullptr,
nullptr,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_inp);
LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, true);
} else {
cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp,
cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, cur, //ffn_inp,
model.layers[il].ffn_gate_inp, model.layers[il].ffn_gate_inp_b,
model.layers[il].ffn_up_exps, model.layers[il].ffn_up_exps_b,
model.layers[il].ffn_gate_exps, model.layers[il].ffn_gate_exps_b,

View File

@ -2549,7 +2549,17 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
if (hparams.has_rope_freq_base_per_layer) {
freq_base_l = hparams.rope_freq_base_per_layer[il];
}
int n_rot_l = lctx.model.hparams.rope_n_rot(il);
int n_rot_l = lctx.model.hparams.rope_n_rot(il);
float ext_factor_l = ext_factor;
float attn_factor_l = attn_factor;
float beta_fast_l = beta_fast;
float beta_slow_l = beta_slow;
if (model.arch == LLM_ARCH_LAGUNA && n_swa > 0) {
ext_factor_l = 0.0f;
attn_factor_l = 1.0f;
beta_fast_l = 32.0f;
beta_slow_l = 1.0f;
}
#ifdef GGML_USE_VULKAN
constexpr bool use_f32_precision = true;
#else
@ -2650,15 +2660,15 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
std::copy(hparams.rope_sections.begin(), hparams.rope_sections.begin() + GGML_MROPE_SECTIONS, sections);
Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, rope_factors,
n_rot_l, sections, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
ext_factor_l, attn_factor_l, beta_fast_l, beta_slow_l);
Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, rope_factors,
n_rot_l, sections, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
ext_factor_l, attn_factor_l, beta_fast_l, beta_slow_l);
} else {
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
ext_factor_l, attn_factor_l, beta_fast_l, beta_slow_l);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
ext_factor_l, attn_factor_l, beta_fast_l, beta_slow_l);
}
}
cb(Qcur, "Qcur", il_cb);
@ -2765,11 +2775,18 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
auto wqkv_gate = (ggml_split_tensor_t *)model.layers[il].wqkv_gate->extra;
GGML_ASSERT(wqkv_gate && wqkv_gate->splits[id]);
auto gate = llm_build_lora_mm(lctx, ctx0, wqkv_gate->splits[id], input_normed);
if (model.arch == LLM_ARCH_LAGUNA) {
gate = ggml_softplus(ctx0, gate);
}
cb(gate, "attn_gate", il_cb);
int nh = split_wo->ne[0]/n_embd_head_v;
auto attn_3d = ggml_reshape_3d(ctx0, cur, n_embd_head_v, nh, n_tokens);
auto gate_3d = ggml_reshape_3d(ctx0, gate, 1, nh, n_tokens);
cur = ggml_fused_mul_unary(ctx0, gate_3d, attn_3d, GGML_UNARY_OP_SIGMOID);
if (model.arch == LLM_ARCH_LAGUNA) {
cur = ggml_mul(ctx0, attn_3d, gate_3d);
} else {
cur = ggml_fused_mul_unary(ctx0, gate_3d, attn_3d, GGML_UNARY_OP_SIGMOID);
}
cb(attn_3d, "attn_gated_3d", il_cb);
}
@ -2861,15 +2878,15 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
std::copy(hparams.rope_sections.begin(), hparams.rope_sections.begin() + GGML_MROPE_SECTIONS, sections);
Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, rope_factors_in,
n_rot_l, sections, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
ext_factor_l, attn_factor_l, beta_fast_l, beta_slow_l);
Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, rope_factors_in,
n_rot_l, sections, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
ext_factor_l, attn_factor_l, beta_fast_l, beta_slow_l);
} else {
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors_in, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
ext_factor_l, attn_factor_l, beta_fast_l, beta_slow_l);
Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, rope_factors_in, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
ext_factor_l, attn_factor_l, beta_fast_l, beta_slow_l);
}
}
cb(Qcur, "Qcur_roped", il);
@ -2886,11 +2903,18 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, KQ_scale, cb, il, sinks, n_swa);
cb(cur, "wqkv", il);
auto gate = llm_build_lora_mm(lctx, ctx0, wqkv_gate, input_normed); // [n_head_l, n_tokens]
if (model.arch == LLM_ARCH_LAGUNA) {
gate = ggml_softplus(ctx0, gate);
}
cb(gate, "attn_gate", il);
int n_head_l = hparams.n_head(il);
auto attn_3d = ggml_reshape_3d(ctx0, cur, n_embd_head_v, n_head_l, n_tokens);
auto gate_3d = ggml_reshape_3d(ctx0, gate, 1, n_head_l, n_tokens);
cur = ggml_fused_mul_unary(ctx0, gate_3d, attn_3d, GGML_UNARY_OP_SIGMOID);
if (model.arch == LLM_ARCH_LAGUNA) {
cur = ggml_mul(ctx0, attn_3d, gate_3d);
} else {
cur = ggml_fused_mul_unary(ctx0, gate_3d, attn_3d, GGML_UNARY_OP_SIGMOID);
}
cb(cur, "attn_gated_3d", il);
cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v * n_head_l, n_tokens);
cb(cur, "attn_gated", il);

View File

@ -3079,6 +3079,7 @@ static bool is_model_split_supported(const llama_model & model) {
LLM_ARCH_GLM_DSA,
LLM_ARCH_MISTRAL4,
LLM_ARCH_MELLUM,
LLM_ARCH_LAGUNA,
};
auto it = k_supported.find(model.arch);
return it != k_supported.end();
@ -6877,10 +6878,11 @@ struct llama_context * llama_init_from_model(
cparams.reduce_type = GGML_TYPE_F32;
}
}
if (model->arch == LLM_ARCH_MELLUM && model->split_mode == LLAMA_SPLIT_MODE_GRAPH) {
if ((model->arch == LLM_ARCH_MELLUM || model->arch == LLM_ARCH_LAGUNA) && model->split_mode == LLAMA_SPLIT_MODE_GRAPH) {
if (cparams.reduce_type == GGML_TYPE_F16) {
const char * mname = model->arch == LLM_ARCH_MELLUM ? "Mellum" : "Laguna";
LLAMA_LOG_WARN("=====================================================================\n");
LLAMA_LOG_WARN("Mellum with split mode graph requires bf16 or f32 precision\n");
LLAMA_LOG_WARN("%s with split mode graph requires bf16 or f32 precision\n", mname);
LLAMA_LOG_WARN(" => changing cparams.reduce_type to bf16\n");
LLAMA_LOG_WARN("=====================================================================\n");
cparams.reduce_type = GGML_TYPE_BF16;