Optimize Cohere2-MoE graph parallel (#1948)

* Optimize Cohere2-MoE graph parallel

* Minor
This commit is contained in:
Kawrakow 2026-06-11 07:26:42 +02:00 committed by GitHub
parent ca0c1c5f85
commit 022bd00aab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 30 additions and 5 deletions

View File

@ -30,9 +30,10 @@ ggml_cgraph * llm_build_context::build_cohere2_moe() {
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
attn_out->op_params[3] = 1;
ggml_tensor * cur;
if (model.layers[il].ffn_gate_inp == nullptr) {
attn_out->op_params[3] = 1;
cur = llm_build_ffn(ctx0, lctx, model.layers[il].attn_norm, inpL,
model.layers[il].ffn_up, nullptr, nullptr,
model.layers[il].ffn_gate, nullptr, nullptr,
@ -52,8 +53,8 @@ ggml_cgraph * llm_build_context::build_cohere2_moe() {
n_expert, n_expert_used,
LLM_FFN_SILU, hparams.expert_weights_norm, false, 0.0f,
(llm_expert_gating_func_type) hparams.expert_gating_func,
LLM_FFN_SILU, cb, il, gf, false, model.layers[il].ffn_up_gate_exps, nullptr, nullptr);
cur = ggml_add(ctx0, cur, attn_out);
LLM_FFN_SILU, cb, il, gf, false, model.layers[il].ffn_up_gate_exps, nullptr, nullptr,
attn_out);
}
cb(cur, "ffn_out", il);

View File

@ -1293,7 +1293,8 @@ llm_expert_gating_func_type gating_op,
llm_ffn_op_type type_op_shexp,
const llm_build_cb & cb, int il, ggml_cgraph * graph, bool add_input,
ggml_tensor * up_gate_exps, ggml_tensor * up_gate_exps_b,
ggml_tensor * shexp_gate) {
ggml_tensor * shexp_gate,
ggml_tensor * add_extra) {
auto split_up_exps = up_exps ? (ggml_split_tensor_t *)up_exps->extra : nullptr;
auto split_gate_exps = gate_exps ? (ggml_split_tensor_t *)gate_exps->extra : nullptr;
@ -1336,6 +1337,7 @@ llm_expert_gating_func_type gating_op,
}
ggml_build_forward_expand(graph, routed_out);
bool handled_add_extra = false;
if (up_shexp && gate_shexp && down_shexp) {
if (split_up_shexp) {
std::vector<ggml_tensor *> results(split_up_shexp->n_device, nullptr);
@ -1387,6 +1389,12 @@ llm_expert_gating_func_type gating_op,
shared_out = ggml_add(ctx, shared_out, routed_out);
cb(shared_out, "ffn_shared_routed_added", il);
}
if (add_extra && add_extra->op == GGML_OP_REDUCE && add_extra->op_params[3] == 1) {
GGML_ASSERT(add_extra->src[id]); // TODO: fix this! It can be null if the splits of the attention and ffn tensors are different
shared_out = ggml_add(ctx, shared_out, add_extra->src[id]);
cb(shared_out, "ffn_shared_with_extra", il_cb);
handled_add_extra = true;
}
if (shared_out->ne[1] > 32 && lctx.cparams.reduce_type != GGML_TYPE_F32) {
shared_out = ggml_cast(ctx, shared_out, lctx.cparams.reduce_type);
}
@ -1422,6 +1430,13 @@ llm_expert_gating_func_type gating_op,
} else {
cur = routed_out;
}
if (add_extra && !handled_add_extra) {
if (add_extra->op == GGML_OP_REDUCE && add_extra->op_params[3] == 1) {
add_extra->op_params[3] = 0;
}
cur = ggml_add(ctx, cur, add_extra);
cb(cur, "ffn_with_extra", il);
}
if (cur != routed_out) {
ggml_build_forward_expand(graph, cur);
}
@ -1505,6 +1520,11 @@ llm_expert_gating_func_type gating_op,
} else {
cur = routed_out;
}
if (add_extra && add_extra->op == GGML_OP_REDUCE && add_extra->op_params[3] == 1) {
GGML_ASSERT(add_extra->src[id]); // TODO: fix this! It can be null if the splits of the attention and ffn tensors are different
cur = ggml_add(ctx, cur, add_extra->src[id]);
cb(cur, "ffn_with_extra", il_cb);
}
if (cur->ne[1] > 32 && lctx.cparams.reduce_type != GGML_TYPE_F32) {
cur = ggml_cast(ctx, cur, lctx.cparams.reduce_type);
cb(cur, "ffn_out_f16", il_cb);
@ -1519,6 +1539,10 @@ llm_expert_gating_func_type gating_op,
results[last_id] = ggml_add(ctx, results[last_id], input);
cb(results[last_id], "ffn_inp_added", il);
}
if (add_extra && !(add_extra->op == GGML_OP_REDUCE && add_extra->op_params[3] == 1)) {
results[last_id] = ggml_add(ctx, results[last_id], add_extra);
cb(results[last_id], "ffn_with_inp", il);
}
auto cur = ggml_reduce(ctx, results.data(), n_device, GGML_OP_ADD);
cb(cur, "moe_ffn_combined", il);

View File

@ -454,7 +454,7 @@ llm_expert_gating_func_type gating_op,
llm_ffn_op_type type_op_shexp,
const llm_build_cb & cb, int il, ggml_cgraph * graph, bool add_input = false,
ggml_tensor * up_gate_exps = nullptr, ggml_tensor * up_gate_exps_b = nullptr,
ggml_tensor * shexp_gate = nullptr);
ggml_tensor * shexp_gate = nullptr, ggml_tensor * add_extra = nullptr);
static ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids);