Fuse some ops for Gemma4-MoE (#1610)

This commit is contained in:
Kawrakow 2026-04-11 08:11:54 +02:00 committed by GitHub
parent 2c455ec468
commit b0750b5d43
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 97 additions and 5 deletions

View File

@ -68,6 +68,40 @@ static void mul_multi_add_f32_cuda(int nused, int64_t ne0, int64_t ne1, int64_t
mul_multi_add_f32<<<num_blocks, CUDA_MULTI_ADD_BLOCK_SIZE, 0, stream>>>(nused, ne0, ne1, nb1, nb01, nb02, nb11, nb12, src0, src1, dst);
}
static __global__ void mul_multi_add_f32(int nused, int64_t ne0, int64_t ne1, int64_t nb1, int64_t nb01, int64_t nb02, int64_t nb11, int64_t nb12, int64_t nb31,
const char * src0, const char * src1, char * dst, const float * scales, const char * cids) {
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
int64_t k = ne0*ne1;
if (i >= k) {
return;
}
int i1 = i / ne0;
int i0 = i % ne0;
float * result = (float *)(dst + i1*nb1);
const int * ids = (const int *)(cids + i1 * nb31);
auto c0 = src0 + i1*nb02;
auto c1 = src1 + i1*nb12;
float sum = 0;
for (int j = 0; j < nused; ++j) {
auto x0 = (const float *)c0;
auto x1 = (const float *)c1;
sum += x0[i0] * x1[0] * scales[ids[j]];
c0 += nb01;
c1 += nb11;
}
result[i0] = sum;
}
static void mul_multi_add_f32_cuda(int nused, int64_t ne0, int64_t ne1, int64_t nb1, int64_t nb01, int64_t nb02, int64_t nb11, int64_t nb12, int64_t nb31,
const char * src0, const char * src1, char * dst, const float * scales, const char * ids, cudaStream_t stream) {
int64_t k = ne0 * ne1;
const int num_blocks = (k + CUDA_MULTI_ADD_BLOCK_SIZE - 1) / CUDA_MULTI_ADD_BLOCK_SIZE;
mul_multi_add_f32<<<num_blocks, CUDA_MULTI_ADD_BLOCK_SIZE, 0, stream>>>(nused, ne0, ne1, nb1, nb01, nb02, nb11, nb12, nb31, src0, src1, dst, scales, ids);
}
void ggml_cuda_op_mul_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
auto src0 = dst->src[0];
auto src1 = dst->src[1];
@ -82,6 +116,19 @@ void ggml_cuda_op_mul_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * d
GGML_ASSERT(src0->ne[3] == 1);
GGML_ASSERT(src1->ne[0] == 1);
auto src2 = dst->src[2];
auto src3 = dst->src[3];
if (src2 && src3) {
GGML_ASSERT(src3->ne[0] == src0->ne[1]);
GGML_ASSERT(src3->type == GGML_TYPE_I32);
GGML_ASSERT(src2->type == GGML_TYPE_F32);
mul_multi_add_f32_cuda(src0->ne[1], dst->ne[0], dst->ne[1], dst->nb[1], src0->nb[1], src0->nb[2], src1->nb[1], src1->nb[2], src3->nb[1],
(const char *)src0->data, (const char *)src1->data, (char *)dst->data, (const float *)src2->data, (const char *)src3->data, ctx.stream());
return;
}
mul_multi_add_f32_cuda(src0->ne[1], dst->ne[0], dst->ne[1], dst->nb[1], src0->nb[1], src0->nb[2], src1->nb[1], src1->nb[2],
(const char *)src0->data, (const char *)src1->data, (char *)dst->data, ctx.stream());
}

View File

@ -449,6 +449,39 @@ void iqk_mul_multi_add(struct ggml_tensor * dst, int ith, int nth) {
int ne01 = src0->ne[1];
int ne00 = src0->ne[0];
auto src2 = dst->src[2];
auto src3 = dst->src[3];
if (src2 && src3) {
GGML_ASSERT(src2->type == GGML_TYPE_F32);
GGML_ASSERT(src3->type == GGML_TYPE_I32);
GGML_ASSERT(src3->ne[0] == src0->ne[1]);
auto cids = (const char *)src3->data;
auto scales = (const float *)src2->data;
for (int ir = first; ir < last; ++ir) {
auto c0 = (const char *)src0->data + ir*src0->nb[2];
auto c1 = (const char *)src1->data + ir*src1->nb[2];
auto cy = ( char *)dst->data + ir* dst->nb[1];
auto y = ( float *)cy;
auto x0 = (const float *)c0;
auto x1 = (const float *)c1;
auto ids = (const int *)(cids + ir*src3->nb[1]);
float s = scales[ids[0]] * x1[0];
for (int k = 0; k < ne00; ++k) y[k] = x0[k] * s;
for (int j = 1; j < ne01; ++j) {
c0 += src0->nb[1];
c1 += src1->nb[1];
x0 = (const float *)c0;
x1 = (const float *)c1;
s = x1[0] * scales[ids[j]];
for (int k = 0; k < ne00; ++k) y[k] += x0[k] * s;
}
}
return;
}
for (int ir = first; ir < last; ++ir) {
auto c0 = (const char *)src0->data + ir*src0->nb[2];
auto c1 = (const char *)src1->data + ir*src1->nb[2];

View File

@ -1222,18 +1222,25 @@ llm_expert_gating_func_type gating_op,
cb(experts, "ffn_moe_down_biased", il);
}
if (down_exps_s) {
ggml_tensor * s = ggml_reshape_3d(ctx, down_exps_s, 1, n_expert, 1);
if (down_exps_s && !lctx.cparams.fused_mmad) {
GGML_ASSERT(!weight_before_ffn);
auto s = ggml_reshape_3d(ctx, down_exps_s, 1, n_expert, 1);
s = ggml_repeat_4d(ctx, s, 1, n_expert, n_tokens, 1);
s = ggml_get_rows(ctx, s, selected_experts); // [1, n_expert_used, n_tokens]
experts = ggml_mul(ctx, experts, s);
cb(experts, "ffn_moe_down_scaled", il);
s = ggml_get_rows(ctx, s, selected_experts);
auto w_reshaped = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
auto s_reshaped = ggml_reshape_2d(ctx, s, n_expert_used, n_tokens);
w_reshaped = ggml_mul(ctx, w_reshaped, s_reshaped);
weights = ggml_reshape_3d(ctx, w_reshaped, 1, n_expert_used, n_tokens);
}
if (!weight_before_ffn) {
if (lctx.cparams.fused_mmad) {
experts = ggml_mul_multi_add(ctx, experts, weights);
cb(experts, "ffn_moe_weighted", il);
if (down_exps_s) {
experts->src[2] = down_exps_s;
experts->src[3] = selected_experts;
}
if (add_input) {
experts = ggml_add(ctx, experts, input);
cb(experts, "ffn_out_with_inp", il);
@ -6299,10 +6306,15 @@ static ggml_cgraph * build_gemma4_graph_paralle(llm_build_context & llm, llama_c
if (is_moe) {
cur = do_split_norm(ctx0, ffn_inp[id], model.layers[il].ffn_pre_norm_2, hparams, cb, id, il_cb, false);
cb(cur, "ffn_moe_inp", il_cb);
auto tmp = ggml_rms_norm(ctx0, ffn_inp[id], hparams.f_norm_rms_eps);
cb(tmp, "tmp", il_cb);
tmp = ggml_scale(ctx0, tmp, 1.0f / sqrtf((float) hparams.n_embd));
cb(tmp, "tmp_scaled", il_cb);
tmp = ggml_mul(ctx0, tmp, ((const ggml_split_tensor_t *)model.layers[il].ffn_gate_inp_s->extra)->splits[id]);
cb(tmp, "tmp_mul", il_cb);
auto logits = llm.llm_build_lora_mm(lctx, ctx0, ((const ggml_split_tensor_t *)model.layers[il].ffn_gate_inp->extra)->splits[id], tmp);
cb(logits, "logits", il_cb);
auto moe = llm. llm_build_moe_ffn(ctx0, lctx, cur,
nullptr, nullptr, nullptr,