From 1255b1e479a3ac8b63f3553bb0e25772a8de711b Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 26 Jun 2026 10:31:03 +0200 Subject: [PATCH] Minor DFlash tweaks (#2034) --- src/graphs/build_dflash.cpp | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/src/graphs/build_dflash.cpp b/src/graphs/build_dflash.cpp index ed867e10..273372e6 100644 --- a/src/graphs/build_dflash.cpp +++ b/src/graphs/build_dflash.cpp @@ -226,22 +226,31 @@ ggml_cgraph * llm_build_context::build_dflash() { cb(cur, "attn_norm", il); ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + ggml_tensor * Kcur_noise = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + ggml_tensor * Vcur_noise = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Qcur, "Qcur", il); + cb(Kcur_noise, "Kcur_noise", il); + cb(Vcur_noise, "Vcur_noise", il); + ggml_build_forward_expand(gf, Qcur); + ggml_build_forward_expand(gf, Kcur_noise); + ggml_build_forward_expand(gf, Vcur_noise); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens); Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il); + cb(Qcur, "Qcur_normed", il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); - cb(Qcur, "Qcur", il); + cb(Qcur, "Qcur_roped", il); - ggml_tensor * Kcur_noise = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); Kcur_noise = ggml_reshape_3d(ctx0, Kcur_noise, n_embd_head_k, n_head_kv, n_tokens); Kcur_noise = llm_build_norm(ctx0, Kcur_noise, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il); + cb(Kcur_noise, "Kcur_normed", il); Kcur_noise = ggml_rope_ext(ctx0, Kcur_noise, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); - cb(Kcur_noise, "Kcur_noise", il); + cb(Kcur_noise, "Kcur_roped", il); - 
ggml_tensor * Vcur_noise = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); Vcur_noise = ggml_reshape_3d(ctx0, Vcur_noise, n_embd_head_v, n_head_kv, n_tokens); cb(Vcur_noise, "Vcur_noise", il); @@ -253,8 +262,10 @@ ggml_cgraph * llm_build_context::build_dflash() { GGML_ASSERT(lctx.dflash.kv.k_ctx_cache[il]->ne[1] >= n_kv_total); GGML_ASSERT(lctx.dflash.kv.v_ctx_cache[il]->ne[1] >= n_kv_total); - ggml_tensor * Kcur_draft = ggml_cont(ctx0, ggml_permute(ctx0, Kcur_noise, 0, 2, 1, 3)); - ggml_tensor * Vcur_draft = ggml_cont(ctx0, ggml_permute(ctx0, Vcur_noise, 0, 2, 1, 3)); + //ggml_tensor * Kcur_draft = ggml_cont(ctx0, ggml_permute(ctx0, Kcur_noise, 0, 2, 1, 3)); + //ggml_tensor * Vcur_draft = ggml_cont(ctx0, ggml_permute(ctx0, Vcur_noise, 0, 2, 1, 3)); + ggml_tensor * Kcur_draft = ggml_permute(ctx0, Kcur_noise, 0, 2, 1, 3); + ggml_tensor * Vcur_draft = ggml_permute(ctx0, Vcur_noise, 0, 2, 1, 3); cb(Kcur_draft, "dflash_main_k_perm_cont", il); cb(Vcur_draft, "dflash_main_v_perm_cont", il); @@ -296,15 +307,17 @@ ggml_cgraph * llm_build_context::build_dflash() { ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); ggml_tensor * k = Kcur; ggml_tensor * v = Vcur; - ggml_tensor * dflash_kq_mask_l = (hparams.swa_layers[il] && dflash_kq_mask_swa != nullptr) - ? dflash_kq_mask_swa - : dflash_kq_mask_full; + bool use_swa = hparams.swa_layers[il] && dflash_kq_mask_swa != nullptr; + ggml_tensor * dflash_kq_mask_l = use_swa ? dflash_kq_mask_swa : dflash_kq_mask_full; cb(q, "q", il); cur = ggml_flash_attn_ext(ctx0, q, k, v, dflash_kq_mask_l, kq_scale, hparams.f_max_alibi_bias, hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f); cb(cur, "flash_attn", il); ggml_build_forward_expand(gf, cur); + if (use_swa) { + cur->op_params[4] = hparams.n_swa; + } cur = ggml_reshape_2d(ctx0, cur, model.layers[il].wo->ne[0], n_tokens); cb(cur, "flash_attn_reshaped", il);