Minor DFlash tweaks

This commit is contained in:
Kawrakow 2026-06-25 15:10:16 +00:00
parent b84902d2ad
commit a4e408611d

View File

@ -226,22 +226,31 @@ ggml_cgraph * llm_build_context::build_dflash() {
cb(cur, "attn_norm", il); cb(cur, "attn_norm", il);
ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
ggml_tensor * Kcur_noise = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
ggml_tensor * Vcur_noise = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
cb(Qcur, "Qcur", il);
cb(Kcur_noise, "Kcur_noise", il);
cb(Vcur_noise, "Vcur_noise", il);
ggml_build_forward_expand(gf, Qcur);
ggml_build_forward_expand(gf, Kcur_noise);
ggml_build_forward_expand(gf, Vcur_noise);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens);
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il); Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur_normed", il);
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow); ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur_roped", il);
ggml_tensor * Kcur_noise = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
Kcur_noise = ggml_reshape_3d(ctx0, Kcur_noise, n_embd_head_k, n_head_kv, n_tokens); Kcur_noise = ggml_reshape_3d(ctx0, Kcur_noise, n_embd_head_k, n_head_kv, n_tokens);
Kcur_noise = llm_build_norm(ctx0, Kcur_noise, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il); Kcur_noise = llm_build_norm(ctx0, Kcur_noise, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(Qcur, "Kcur_normed", il);
Kcur_noise = ggml_rope_ext(ctx0, Kcur_noise, inp_pos, nullptr, Kcur_noise = ggml_rope_ext(ctx0, Kcur_noise, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow); ext_factor, attn_factor, beta_fast, beta_slow);
cb(Kcur_noise, "Kcur_noise", il); cb(Kcur_noise, "Kcur_roped", il);
ggml_tensor * Vcur_noise = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
Vcur_noise = ggml_reshape_3d(ctx0, Vcur_noise, n_embd_head_v, n_head_kv, n_tokens); Vcur_noise = ggml_reshape_3d(ctx0, Vcur_noise, n_embd_head_v, n_head_kv, n_tokens);
cb(Vcur_noise, "Vcur_noise", il); cb(Vcur_noise, "Vcur_noise", il);
@ -253,8 +262,10 @@ ggml_cgraph * llm_build_context::build_dflash() {
GGML_ASSERT(lctx.dflash.kv.k_ctx_cache[il]->ne[1] >= n_kv_total); GGML_ASSERT(lctx.dflash.kv.k_ctx_cache[il]->ne[1] >= n_kv_total);
GGML_ASSERT(lctx.dflash.kv.v_ctx_cache[il]->ne[1] >= n_kv_total); GGML_ASSERT(lctx.dflash.kv.v_ctx_cache[il]->ne[1] >= n_kv_total);
ggml_tensor * Kcur_draft = ggml_cont(ctx0, ggml_permute(ctx0, Kcur_noise, 0, 2, 1, 3)); //ggml_tensor * Kcur_draft = ggml_cont(ctx0, ggml_permute(ctx0, Kcur_noise, 0, 2, 1, 3));
ggml_tensor * Vcur_draft = ggml_cont(ctx0, ggml_permute(ctx0, Vcur_noise, 0, 2, 1, 3)); //ggml_tensor * Vcur_draft = ggml_cont(ctx0, ggml_permute(ctx0, Vcur_noise, 0, 2, 1, 3));
ggml_tensor * Kcur_draft = ggml_permute(ctx0, Kcur_noise, 0, 2, 1, 3);
ggml_tensor * Vcur_draft = ggml_permute(ctx0, Vcur_noise, 0, 2, 1, 3);
cb(Kcur_draft, "dflash_main_k_perm_cont", il); cb(Kcur_draft, "dflash_main_k_perm_cont", il);
cb(Vcur_draft, "dflash_main_v_perm_cont", il); cb(Vcur_draft, "dflash_main_v_perm_cont", il);
@ -296,15 +307,17 @@ ggml_cgraph * llm_build_context::build_dflash() {
ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
ggml_tensor * k = Kcur; ggml_tensor * k = Kcur;
ggml_tensor * v = Vcur; ggml_tensor * v = Vcur;
ggml_tensor * dflash_kq_mask_l = (hparams.swa_layers[il] && dflash_kq_mask_swa != nullptr) bool use_swa = hparams.swa_layers[il] && dflash_kq_mask_swa != nullptr;
? dflash_kq_mask_swa ggml_tensor * dflash_kq_mask_l = use_swa ? dflash_kq_mask_swa : dflash_kq_mask_full;
: dflash_kq_mask_full;
cb(q, "q", il); cb(q, "q", il);
cur = ggml_flash_attn_ext(ctx0, q, k, v, dflash_kq_mask_l, kq_scale, hparams.f_max_alibi_bias, cur = ggml_flash_attn_ext(ctx0, q, k, v, dflash_kq_mask_l, kq_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f); hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
cb(cur, "flash_attn", il); cb(cur, "flash_attn", il);
ggml_build_forward_expand(gf, cur); ggml_build_forward_expand(gf, cur);
if (use_swa) {
cur->op_params[4] = hparams.n_swa;
}
cur = ggml_reshape_2d(ctx0, cur, model.layers[il].wo->ne[0], n_tokens); cur = ggml_reshape_2d(ctx0, cur, model.layers[il].wo->ne[0], n_tokens);
cb(cur, "flash_attn_reshaped", il); cb(cur, "flash_attn_reshaped", il);