From b0750b5d43b3d24650c8a3fa6340446cf64d39d9 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sat, 11 Apr 2026 08:11:54 +0200 Subject: [PATCH] Fuse some ops for Gemma4-MoE (#1610) --- ggml/src/ggml-cuda/multiadd.cu | 47 ++++++++++++++++++++++++++++++++++ ggml/src/iqk/iqk_cpu_ops.cpp | 33 ++++++++++++++++++++++++ src/llama-build-context.cpp | 22 ++++++++++++---- 3 files changed, 97 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/multiadd.cu b/ggml/src/ggml-cuda/multiadd.cu index fba7271a..ed2f4328 100644 --- a/ggml/src/ggml-cuda/multiadd.cu +++ b/ggml/src/ggml-cuda/multiadd.cu @@ -68,6 +68,40 @@ static void mul_multi_add_f32_cuda(int nused, int64_t ne0, int64_t ne1, int64_t mul_multi_add_f32<<>>(nused, ne0, ne1, nb1, nb01, nb02, nb11, nb12, src0, src1, dst); } +static __global__ void mul_multi_add_f32(int nused, int64_t ne0, int64_t ne1, int64_t nb1, int64_t nb01, int64_t nb02, int64_t nb11, int64_t nb12, int64_t nb31, + const char * src0, const char * src1, char * dst, const float * scales, const char * cids) { + const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; + int64_t k = ne0*ne1; + if (i >= k) { + return; + } + int i1 = i / ne0; + int i0 = i % ne0; + float * result = (float *)(dst + i1*nb1); + + const int * ids = (const int *)(cids + i1 * nb31); + + auto c0 = src0 + i1*nb02; + auto c1 = src1 + i1*nb12; + + float sum = 0; + for (int j = 0; j < nused; ++j) { + auto x0 = (const float *)c0; + auto x1 = (const float *)c1; + sum += x0[i0] * x1[0] * scales[ids[j]]; + c0 += nb01; + c1 += nb11; + } + result[i0] = sum; +} + +static void mul_multi_add_f32_cuda(int nused, int64_t ne0, int64_t ne1, int64_t nb1, int64_t nb01, int64_t nb02, int64_t nb11, int64_t nb12, int64_t nb31, + const char * src0, const char * src1, char * dst, const float * scales, const char * ids, cudaStream_t stream) { + int64_t k = ne0 * ne1; + const int num_blocks = (k + CUDA_MULTI_ADD_BLOCK_SIZE - 1) / CUDA_MULTI_ADD_BLOCK_SIZE; + mul_multi_add_f32<<>>(nused, ne0, ne1, nb1, nb01, nb02, nb11, nb12, nb31, src0, src1, dst, scales, ids); +} + void ggml_cuda_op_mul_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { auto src0 = dst->src[0]; auto src1 = dst->src[1]; @@ -82,6 +116,19 @@ void ggml_cuda_op_mul_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * d GGML_ASSERT(src0->ne[3] == 1); GGML_ASSERT(src1->ne[0] == 1); + auto src2 = dst->src[2]; + auto src3 = dst->src[3]; + if (src2 && src3) { + GGML_ASSERT(src3->ne[0] == src0->ne[1]); + GGML_ASSERT(src3->type == GGML_TYPE_I32); + GGML_ASSERT(src2->type == GGML_TYPE_F32); + + mul_multi_add_f32_cuda(src0->ne[1], dst->ne[0], dst->ne[1], dst->nb[1], src0->nb[1], src0->nb[2], src1->nb[1], src1->nb[2], src3->nb[1], + (const char *)src0->data, (const char *)src1->data, (char *)dst->data, (const float *)src2->data, (const char *)src3->data, ctx.stream()); + + return; + } + mul_multi_add_f32_cuda(src0->ne[1], dst->ne[0], dst->ne[1], dst->nb[1], src0->nb[1], src0->nb[2], src1->nb[1], src1->nb[2], (const char *)src0->data, (const char *)src1->data, (char *)dst->data, ctx.stream()); } diff --git a/ggml/src/iqk/iqk_cpu_ops.cpp b/ggml/src/iqk/iqk_cpu_ops.cpp index 4a0e8eee..7a91336b 100644 --- a/ggml/src/iqk/iqk_cpu_ops.cpp +++ b/ggml/src/iqk/iqk_cpu_ops.cpp @@ -449,6 +449,39 @@ void iqk_mul_multi_add(struct ggml_tensor * dst, int ith, int nth) { int ne01 = src0->ne[1]; int ne00 = src0->ne[0]; + auto src2 = dst->src[2]; + auto src3 = dst->src[3]; + if (src2 && src3) { + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src3->type == GGML_TYPE_I32); + GGML_ASSERT(src3->ne[0] == src0->ne[1]); + + auto cids = (const char *)src3->data; + auto scales = (const float *)src2->data; + for (int ir = first; ir < last; ++ir) { + auto c0 = (const char *)src0->data + ir*src0->nb[2]; + auto c1 = (const char *)src1->data + ir*src1->nb[2]; + auto cy = ( char *)dst->data + ir* dst->nb[1]; + auto y = ( float *)cy; + auto x0 = (const float *)c0; + auto x1 = (const float *)c1; + auto ids = (const int *)(cids + ir*src3->nb[1]); + float s = scales[ids[0]] * x1[0]; + for (int k = 0; k < ne00; ++k) y[k] = x0[k] * s; + for (int j = 1; j < ne01; ++j) { + c0 += src0->nb[1]; + c1 += src1->nb[1]; + x0 = (const float *)c0; + x1 = (const float *)c1; + s = x1[0] * scales[ids[j]]; + for (int k = 0; k < ne00; ++k) y[k] += x0[k] * s; + } + } + + return; + + } + for (int ir = first; ir < last; ++ir) { auto c0 = (const char *)src0->data + ir*src0->nb[2]; auto c1 = (const char *)src1->data + ir*src1->nb[2]; diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 69e09993..335794dd 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -1222,18 +1222,25 @@ llm_expert_gating_func_type gating_op, cb(experts, "ffn_moe_down_biased", il); } - if (down_exps_s) { - ggml_tensor * s = ggml_reshape_3d(ctx, down_exps_s, 1, n_expert, 1); + if (down_exps_s && !lctx.cparams.fused_mmad) { + GGML_ASSERT(!weight_before_ffn); + auto s = ggml_reshape_3d(ctx, down_exps_s, 1, n_expert, 1); s = ggml_repeat_4d(ctx, s, 1, n_expert, n_tokens, 1); - s = ggml_get_rows(ctx, s, selected_experts); // [1, n_expert_used, n_tokens] - experts = ggml_mul(ctx, experts, s); - cb(experts, "ffn_moe_down_scaled", il); + s = ggml_get_rows(ctx, s, selected_experts); + auto w_reshaped = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens); + auto s_reshaped = ggml_reshape_2d(ctx, s, n_expert_used, n_tokens); + w_reshaped = ggml_mul(ctx, w_reshaped, s_reshaped); + weights = ggml_reshape_3d(ctx, w_reshaped, 1, n_expert_used, n_tokens); } if (!weight_before_ffn) { if (lctx.cparams.fused_mmad) { experts = ggml_mul_multi_add(ctx, experts, weights); cb(experts, "ffn_moe_weighted", il); + if (down_exps_s) { + experts->src[2] = down_exps_s; + experts->src[3] = selected_experts; + } if (add_input) { experts = ggml_add(ctx, experts, input); cb(experts, "ffn_out_with_inp", il); @@ -6299,10 +6306,15 @@ static ggml_cgraph * build_gemma4_graph_paralle(llm_build_context & llm, llama_c if (is_moe) { cur = do_split_norm(ctx0, ffn_inp[id], model.layers[il].ffn_pre_norm_2, hparams, cb, id, il_cb, false); + cb(cur, "ffn_moe_inp", il_cb); auto tmp = ggml_rms_norm(ctx0, ffn_inp[id], hparams.f_norm_rms_eps); + cb(tmp, "tmp", il_cb); tmp = ggml_scale(ctx0, tmp, 1.0f / sqrtf((float) hparams.n_embd)); + cb(tmp, "tmp_scaled", il_cb); tmp = ggml_mul(ctx0, tmp, ((const ggml_split_tensor_t *)model.layers[il].ffn_gate_inp_s->extra)->splits[id]); + cb(tmp, "tmp_mul", il_cb); auto logits = llm.llm_build_lora_mm(lctx, ctx0, ((const ggml_split_tensor_t *)model.layers[il].ffn_gate_inp->extra)->splits[id], tmp); + cb(logits, "logits", il_cb); auto moe = llm. llm_build_moe_ffn(ctx0, lctx, cur, nullptr, nullptr, nullptr,