mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Graph parallel for Gemma4 MoE (#1600)
* Use build_std_attention for Gemma4 when possible It is possible for the 26b MoE and 31b dense models. It is not possible for the E4B/E2B vaiants because they don't have KV cache in each layer. * Standardize Gemma4 dense ffn * WIP: Gemma4 split mode graph Runs but produces NaNs * WIP: Gemma4 split mode graph Runs but very high PPL. At least it is no longer NaN. * WIP * This works! * Put attn_norm, attn_post_norm, ffn_norm, ffn_post_norm on all GPUs * Fix crash when saving/loading KV cache * WIP: split mode graph for Gemma4-MoE - crashes * Split mode graph for Gemma4-MoE - this works * Disable SWA optimization Something goes wrong there * Consolidate MoE and dense graph parallel
This commit is contained in:
parent
9db5d9907e
commit
847e191936
@ -40,6 +40,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
int ntokens = std::max(FATTN_KQ_STRIDE, int(Q->ne[1]));
|
||||
int nton = FATTN_KQ_STRIDE*((ntokens + n_swa + FATTN_KQ_STRIDE - 1)/FATTN_KQ_STRIDE);
|
||||
int first = K->ne[1] - nton;
|
||||
local_dst = *dst;
|
||||
local_dst.op_params[4] = 0;
|
||||
if (first > 0) {
|
||||
local_dst = *dst;
|
||||
Kl = *K; Kl.ne[1] = nton; Kl.data = (char *)K->data + K->nb[1]*first;
|
||||
@ -51,6 +53,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
local_dst.op_params[4] = 0;
|
||||
dst = &local_dst;
|
||||
}
|
||||
dst = &local_dst;
|
||||
}
|
||||
|
||||
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
||||
|
||||
@ -6042,13 +6042,18 @@ static ggml_cgraph * build_gemma4_graph_paralle(llm_build_context & llm, llama_c
|
||||
GGML_ASSERT(cparams.flash_attn);
|
||||
auto gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false);
|
||||
|
||||
bool is_moe = hparams.n_expert > 0;
|
||||
|
||||
std::vector<ggml_tensor *> sa_inp(n_device, nullptr);
|
||||
std::vector<ggml_tensor *> sa_out(n_device, nullptr);
|
||||
std::vector<ggml_tensor *> ffn_inp(n_device, nullptr);
|
||||
std::vector<ggml_tensor *> ffn_out(n_device, nullptr);
|
||||
std::vector<ggml_tensor *> ffn_out_moe;
|
||||
if (is_moe) {
|
||||
ffn_out_moe.resize(n_device, nullptr);
|
||||
}
|
||||
|
||||
//ggml_tensor * last_ffn_inp = nullptr;
|
||||
//ggml_tensor * last_sa_inp = nullptr;
|
||||
ggml_tensor * inpL_moe = nullptr;
|
||||
|
||||
for (int il = 0; il < hparams.n_layer; ++il) {
|
||||
auto & l = model.layers[il];
|
||||
@ -6057,9 +6062,6 @@ static ggml_cgraph * build_gemma4_graph_paralle(llm_build_context & llm, llama_c
|
||||
const float freq_scale_l = is_sliding ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
|
||||
const int n_rot_l = is_sliding ? hparams.n_rot_swa : hparams.n_rot;
|
||||
const int n_swa = is_sliding ? hparams.n_swa : 0;
|
||||
//const int n_embd_head = hparams.n_embd_head_k(il);
|
||||
//const int n_head = hparams.n_head(il);
|
||||
//const int n_head_kv = hparams.n_head_kv(il);
|
||||
|
||||
struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
|
||||
|
||||
@ -6099,8 +6101,18 @@ static ggml_cgraph * build_gemma4_graph_paralle(llm_build_context & llm, llama_c
|
||||
} else {
|
||||
GGML_ASSERT(inpL->op == GGML_OP_REDUCE);
|
||||
auto cur = get_input_tensor_sm_graph(ctx0, inpL, id);
|
||||
if (is_moe) {
|
||||
GGML_ASSERT(inpL_moe && inpL_moe->op == GGML_OP_REDUCE);
|
||||
cur = do_split_norm(ctx0, cur, model.layers[il-1].ffn_post_norm_1, hparams, cb, id, il_cb, false);
|
||||
cb(cur, "ffn_post_norm", il_cb);
|
||||
auto cur_moe = get_input_tensor_sm_graph(ctx0, inpL_moe, id);
|
||||
cur_moe = do_split_norm(ctx0, cur_moe, model.layers[il-1].ffn_post_norm_2, hparams, cb, id, il_cb, false);
|
||||
cb(cur, "ffn_moe_post_norm", il_cb);
|
||||
cur = ggml_add(ctx0, cur, cur_moe);
|
||||
cb(cur, "ffn_combined", il_cb);
|
||||
}
|
||||
cur = do_split_norm(ctx0, cur, model.layers[il-1].ffn_post_norm, hparams, cb, id, il_cb, false);
|
||||
cb(cur, "ffn_post_norm", il_cb);
|
||||
cb(cur, "ffn_normed", il_cb);
|
||||
auto add = ffn_inp[id];
|
||||
if (!add) {
|
||||
for (int j = 0; j < n_device; ++j) {
|
||||
@ -6214,6 +6226,11 @@ static ggml_cgraph * build_gemma4_graph_paralle(llm_build_context & llm, llama_c
|
||||
ggml_row_size(split_vl->type, n_embd_head_v), 0);
|
||||
cb(v, "v", il_cb);
|
||||
|
||||
//if (il == 0 || il == 5) {
|
||||
// if (il == 0 && id == 0) printf("\n");
|
||||
// if (id == 0) printf("--- il = %d\n", il);
|
||||
// printf("id = %d, q: %ld x %ld x %ld x %ld, k: %ld x %ld x %ld x %ld v: %ld x %ld x %ld x %ld\n", id, q->ne[0], q->ne[1], q->ne[2], q->ne[3], k->ne[0], k->ne[1], k->ne[2], k->ne[3], v->ne[0], v->ne[1], v->ne[2], v->ne[3]);
|
||||
//}
|
||||
cur = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask_l, hparams.f_attention_scale, hparams.f_max_alibi_bias,
|
||||
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
|
||||
cb(cur, "fa", il_cb);
|
||||
@ -6285,12 +6302,43 @@ static ggml_cgraph * build_gemma4_graph_paralle(llm_build_context & llm, llama_c
|
||||
}
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
ffn_out[id] = cur;
|
||||
|
||||
if (is_moe) {
|
||||
cur = do_split_norm(ctx0, ffn_inp[id], model.layers[il].ffn_pre_norm_2, hparams, cb, id, il_cb, false);
|
||||
auto tmp = ggml_rms_norm(ctx0, ffn_inp[id], hparams.f_norm_rms_eps);
|
||||
tmp = ggml_scale(ctx0, tmp, 1.0f / sqrtf((float) hparams.n_embd));
|
||||
tmp = ggml_mul(ctx0, tmp, ((const ggml_split_tensor_t *)model.layers[il].ffn_gate_inp_s->extra)->splits[id]);
|
||||
auto logits = llm.llm_build_lora_mm(lctx, ctx0, ((const ggml_split_tensor_t *)model.layers[il].ffn_gate_inp->extra)->splits[id], tmp);
|
||||
|
||||
auto moe = llm. llm_build_moe_ffn(ctx0, lctx, cur,
|
||||
nullptr, nullptr, nullptr,
|
||||
((const ggml_split_tensor_t *)model.layers[il].ffn_down_exps->extra)->splits[id], nullptr,
|
||||
llm.n_expert, llm.n_expert_used,
|
||||
LLM_FFN_GELU, true, false, 0.0f,
|
||||
LLM_EXPERT_GATING_FUNC_SOFTMAX,
|
||||
cb, il, gf, false,
|
||||
((const ggml_split_tensor_t *)model.layers[il].ffn_up_gate_exps->extra)->splits[id],
|
||||
nullptr, logits, ((const ggml_split_tensor_t *)model.layers[il].ffn_down_exps_s->extra)->splits[id]);
|
||||
if (moe->ne[1] > 32 && cparams.reduce_type != GGML_TYPE_F32) {
|
||||
moe = ggml_cast(ctx0, moe, cparams.reduce_type);
|
||||
cb(moe, "ffn_moe_cast", il_cb);
|
||||
}
|
||||
ggml_build_forward_expand(gf, moe);
|
||||
ffn_out_moe[id] = moe;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
inpL = ggml_reduce(ctx0, ffn_out.data(), n_device, GGML_OP_ADD);
|
||||
cb(inpL, "ffn_reduce", il);
|
||||
ggml_build_forward_expand(gf, inpL);
|
||||
|
||||
if (is_moe) {
|
||||
inpL_moe = ggml_reduce(ctx0, ffn_out_moe.data(), n_device, GGML_OP_ADD);
|
||||
cb(inpL_moe, "ffn_moe_reduce", il);
|
||||
ggml_build_forward_expand(gf, inpL_moe);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
int idx = lctx.model.default_layer_device[lctx.model.hparams.n_layer];
|
||||
@ -6301,9 +6349,25 @@ static ggml_cgraph * build_gemma4_graph_paralle(llm_build_context & llm, llama_c
|
||||
cur = inpL->view_src;
|
||||
}
|
||||
|
||||
auto post_norm = (const ggml_split_tensor_t *)model.layers[hparams.n_layer-1].ffn_post_norm->extra;
|
||||
auto post_norm = (const ggml_split_tensor_t *)model.layers[hparams.n_layer-1].ffn_post_norm->extra;
|
||||
if (is_moe) {
|
||||
auto cur_moe = inpL_moe->src[idx];
|
||||
if (!cur_moe) {
|
||||
cur_moe = inpL_moe->view_src;
|
||||
}
|
||||
auto post_norm_1 = (const ggml_split_tensor_t *)model.layers[hparams.n_layer-1].ffn_post_norm_1->extra;
|
||||
auto post_norm_2 = (const ggml_split_tensor_t *)model.layers[hparams.n_layer-1].ffn_post_norm_2->extra;
|
||||
cur = llm.llm_build_norm(ctx0, cur, hparams, post_norm_1->splits[idx], NULL, LLM_NORM_RMS, cb, -1);
|
||||
cur->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff;
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
cur_moe = llm.llm_build_norm(ctx0, cur_moe, hparams, post_norm_2->splits[idx], NULL, LLM_NORM_RMS, cb, -1);
|
||||
cb(cur, "ffn_post", hparams.n_layer-1);
|
||||
cb(cur_moe, "ffn_post_moe", hparams.n_layer-1);
|
||||
cur = ggml_add(ctx0, cur, cur_moe);
|
||||
cb(cur, "ffn_combined", hparams.n_layer-1);
|
||||
}
|
||||
cur = llm.llm_build_norm(ctx0, cur, hparams, post_norm->splits[idx], NULL, LLM_NORM_RMS, cb, -1);
|
||||
cb(cur, "ffn_post", hparams.n_layer-1);
|
||||
cb(cur, "ffn_normed", hparams.n_layer-1);
|
||||
auto add = ffn_inp[idx];
|
||||
if (!add) {
|
||||
for (int j = 0; j < n_device; ++j) {
|
||||
@ -6322,6 +6386,7 @@ static ggml_cgraph * build_gemma4_graph_paralle(llm_build_context & llm, llama_c
|
||||
}
|
||||
|
||||
cur = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb);
|
||||
cb(cur, "almost_result", -1);
|
||||
if (hparams.f_final_logit_softcapping > 0) {
|
||||
cur = ggml_softcap(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping, hparams.f_final_logit_softcapping);
|
||||
}
|
||||
@ -6405,7 +6470,7 @@ ggml_cgraph * llm_build_context::build_gemma4() {
|
||||
|
||||
if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH) {
|
||||
return build_gemma4_graph_paralle(*this, lctx, ctx0, inpL, inp_pos, inp_out_ids,
|
||||
KQ_mask, KQ_mask_swa, n_tokens, cb);
|
||||
KQ_mask, KQ_mask_swa, n_tokens, cb);
|
||||
}
|
||||
|
||||
auto gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false);
|
||||
|
||||
@ -4057,12 +4057,6 @@ bool create_tensors_helper::create_tensors() {
|
||||
if (model.tok_embd_per_layer) {
|
||||
supported = false;
|
||||
}
|
||||
for (auto & l : model.layers) {
|
||||
if (l.ffn_gate_inp) {
|
||||
supported = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!supported) {
|
||||
LLAMA_LOG_WARN("\n=========================================================\n");
|
||||
LLAMA_LOG_WARN("Split mode 'graph' is not supported for this Gemma4 variant\n");
|
||||
@ -4268,6 +4262,31 @@ bool create_tensors_helper::create_tensors() {
|
||||
prepare_split_tensors(-1, ctx_split, layer.ffn_post_norm, layer.split_ffn_post_norm, mirror, mem_used);
|
||||
}
|
||||
}
|
||||
if (layer.ffn_post_norm_1) {
|
||||
if (auto it = split_tensors.find(layer.ffn_post_norm_1); it != split_tensors.end()) {
|
||||
prepare_split_tensors(-1, ctx_split, layer.ffn_post_norm_1, layer.split_ffn_post_norm_1, mirror, mem_used);
|
||||
}
|
||||
}
|
||||
if (layer.ffn_post_norm_2) {
|
||||
if (auto it = split_tensors.find(layer.ffn_post_norm_2); it != split_tensors.end()) {
|
||||
prepare_split_tensors(-1, ctx_split, layer.ffn_post_norm_2, layer.split_ffn_post_norm_2, mirror, mem_used);
|
||||
}
|
||||
}
|
||||
if (layer.ffn_pre_norm_2) {
|
||||
if (auto it = split_tensors.find(layer.ffn_pre_norm_2); it != split_tensors.end()) {
|
||||
prepare_split_tensors(-1, ctx_split, layer.ffn_pre_norm_2, layer.split_ffn_pre_norm_2, mirror, mem_used);
|
||||
}
|
||||
}
|
||||
if (layer.ffn_down_exps_s) {
|
||||
if (auto it = split_tensors.find(layer.ffn_down_exps_s); it != split_tensors.end()) {
|
||||
prepare_split_tensors(-1, ctx_split, layer.ffn_down_exps_s, layer.split_ffn_down_exps_s, mirror, mem_used);
|
||||
}
|
||||
}
|
||||
if (layer.ffn_gate_inp_s) {
|
||||
if (auto it = split_tensors.find(layer.ffn_gate_inp_s); it != split_tensors.end()) {
|
||||
prepare_split_tensors(-1, ctx_split, layer.ffn_gate_inp_s, layer.split_ffn_gate_inp_s, mirror, mem_used);
|
||||
}
|
||||
}
|
||||
|
||||
if (layer.ffn_down && layer.ffn_up && layer.ffn_gate) {
|
||||
bool use_split = split_tensors.find(layer.ffn_down) != split_tensors.end() &&
|
||||
|
||||
@ -259,6 +259,9 @@ struct llama_layer {
|
||||
llama_split_tensor split_ffn_norm;
|
||||
llama_split_tensor split_ffn_up_gate;
|
||||
llama_split_tensor split_ffn_post_norm;
|
||||
llama_split_tensor split_ffn_post_norm_1;
|
||||
llama_split_tensor split_ffn_post_norm_2;
|
||||
llama_split_tensor split_ffn_pre_norm_2;
|
||||
|
||||
// ff MoE
|
||||
struct ggml_tensor * ffn_gate_inp = nullptr;
|
||||
@ -300,6 +303,8 @@ struct llama_layer {
|
||||
llama_split_tensor split_ffn_down_exps_b;
|
||||
llama_split_tensor split_ffn_up_exps_b;
|
||||
llama_split_tensor split_ffn_up_gate_exps_b;
|
||||
llama_split_tensor split_ffn_down_exps_s;
|
||||
llama_split_tensor split_ffn_gate_inp_s;
|
||||
|
||||
// ff bias
|
||||
struct ggml_tensor * ffn_gate_b = nullptr;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user