From 022bd00aab9ec8428c4811275de89796c677d278 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 11 Jun 2026 07:26:42 +0200 Subject: [PATCH] Optimize Cohere2-MoE graph parallel (#1948) * Optimzie Cohere2-MoE graph parallel * Minor --- src/graphs/build_cohere2_moe.cpp | 7 ++++--- src/llama-build-context.cpp | 26 +++++++++++++++++++++++++- src/llama-build-context.h | 2 +- 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/src/graphs/build_cohere2_moe.cpp b/src/graphs/build_cohere2_moe.cpp index 05492a5e..3de6e230 100644 --- a/src/graphs/build_cohere2_moe.cpp +++ b/src/graphs/build_cohere2_moe.cpp @@ -30,9 +30,10 @@ ggml_cgraph * llm_build_context::build_cohere2_moe() { inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); } + attn_out->op_params[3] = 1; + ggml_tensor * cur; if (model.layers[il].ffn_gate_inp == nullptr) { - attn_out->op_params[3] = 1; cur = llm_build_ffn(ctx0, lctx, model.layers[il].attn_norm, inpL, model.layers[il].ffn_up, nullptr, nullptr, model.layers[il].ffn_gate, nullptr, nullptr, @@ -52,8 +53,8 @@ ggml_cgraph * llm_build_context::build_cohere2_moe() { n_expert, n_expert_used, LLM_FFN_SILU, hparams.expert_weights_norm, false, 0.0f, (llm_expert_gating_func_type) hparams.expert_gating_func, - LLM_FFN_SILU, cb, il, gf, false, model.layers[il].ffn_up_gate_exps, nullptr, nullptr); - cur = ggml_add(ctx0, cur, attn_out); + LLM_FFN_SILU, cb, il, gf, false, model.layers[il].ffn_up_gate_exps, nullptr, nullptr, + attn_out); } cb(cur, "ffn_out", il); diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index e7f3d7b4..1696812b 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -1293,7 +1293,8 @@ llm_expert_gating_func_type gating_op, llm_ffn_op_type type_op_shexp, const llm_build_cb & cb, int il, ggml_cgraph * graph, bool add_input, ggml_tensor * up_gate_exps, ggml_tensor * up_gate_exps_b, - ggml_tensor * shexp_gate) { + ggml_tensor * shexp_gate, + ggml_tensor * add_extra) { auto split_up_exps = up_exps ? (ggml_split_tensor_t *)up_exps->extra : nullptr; auto split_gate_exps = gate_exps ? (ggml_split_tensor_t *)gate_exps->extra : nullptr; @@ -1336,6 +1337,7 @@ llm_expert_gating_func_type gating_op, } ggml_build_forward_expand(graph, routed_out); + bool handled_add_extra = false; if (up_shexp && gate_shexp && down_shexp) { if (split_up_shexp) { std::vector results(split_up_shexp->n_device, nullptr); @@ -1387,6 +1389,12 @@ llm_expert_gating_func_type gating_op, shared_out = ggml_add(ctx, shared_out, routed_out); cb(shared_out, "ffn_shared_routed_added", il); } + if (add_extra && add_extra->op == GGML_OP_REDUCE && add_extra->op_params[3] == 1) { + GGML_ASSERT(add_extra->src[id]); // TODO: fix this! It can be null if the splits of the attention and ffn tensors are different + shared_out = ggml_add(ctx, shared_out, add_extra->src[id]); + cb(shared_out, "ffn_shared_with_extra", il_cb); + handled_add_extra = true; + } if (shared_out->ne[1] > 32 && lctx.cparams.reduce_type != GGML_TYPE_F32) { shared_out = ggml_cast(ctx, shared_out, lctx.cparams.reduce_type); } @@ -1422,6 +1430,13 @@ llm_expert_gating_func_type gating_op, } else { cur = routed_out; } + if (add_extra && !handled_add_extra) { + if (add_extra->op == GGML_OP_REDUCE && add_extra->op_params[3] == 1) { + add_extra->op_params[3] = 0; + } + cur = ggml_add(ctx, cur, add_extra); + cb(cur, "ffn_with_extra", il); + } if (cur != routed_out) { ggml_build_forward_expand(graph, cur); } @@ -1505,6 +1520,11 @@ llm_expert_gating_func_type gating_op, } else { cur = routed_out; } + if (add_extra && add_extra->op == GGML_OP_REDUCE && add_extra->op_params[3] == 1) { + GGML_ASSERT(add_extra->src[id]); // TODO: fix this! It can be null if the splits of the attention and ffn tensors are different + cur = ggml_add(ctx, cur, add_extra->src[id]); + cb(cur, "ffn_with_extra", il_cb); + } if (cur->ne[1] > 32 && lctx.cparams.reduce_type != GGML_TYPE_F32) { cur = ggml_cast(ctx, cur, lctx.cparams.reduce_type); cb(cur, "ffn_out_f16", il_cb); @@ -1519,6 +1539,10 @@ llm_expert_gating_func_type gating_op, results[last_id] = ggml_add(ctx, results[last_id], input); cb(results[last_id], "ffn_inp_added", il); } + if (add_extra && !(add_extra->op == GGML_OP_REDUCE && add_extra->op_params[3] == 1)) { + results[last_id] = ggml_add(ctx, results[last_id], add_extra); + cb(results[last_id], "ffn_with_inp", il); + } auto cur = ggml_reduce(ctx, results.data(), n_device, GGML_OP_ADD); cb(cur, "moe_ffn_combined", il); diff --git a/src/llama-build-context.h b/src/llama-build-context.h index 7afc5818..63a33545 100644 --- a/src/llama-build-context.h +++ b/src/llama-build-context.h @@ -454,7 +454,7 @@ llm_expert_gating_func_type gating_op, llm_ffn_op_type type_op_shexp, const llm_build_cb & cb, int il, ggml_cgraph * graph, bool add_input = false, ggml_tensor * up_gate_exps = nullptr, ggml_tensor * up_gate_exps_b = nullptr, - ggml_tensor * shexp_gate = nullptr); + ggml_tensor * shexp_gate = nullptr, ggml_tensor * add_extra = nullptr); static ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids);