From bdf5c081dcda2d746c3eaebc686a0352974265da Mon Sep 17 00:00:00 2001 From: Joel Farthing Date: Thu, 25 Jun 2026 02:06:54 -0500 Subject: [PATCH] DFlash: enable sliding-window attention for draft models (#2021) * DFlash: bound intra-block draft tokens to the SWA window The SWA mask builder applied the sliding-window distance check only to the cross-context section; the intra-block draft-token loop masked causal-only, so a draft token could attend to earlier block tokens beyond n_swa. Apply the same window bound ((j - block_k) < swa_window) in both the F16 and F32 paths so it matches the cross-context section. Behavior-neutral for dense models: the SWA mask tensor is only allocated when the model has SWA layers (build_dflash.cpp needs_swa_mask gate), so for dense targets the changed block is unreachable. * DFlash: enable sliding-window attention for draft models DFlash drafts can be trained with sliding-window attention for long context, but the runtime ignored it: the draft loader never read the window keys and the converter never emitted them, so SWA-trained drafts always ran full-attention. Enable it end to end and fix the dormant SWA graph path it exposes: - convert_hf_to_gguf.py (DFlashDraftModel): emit attention.sliding_window + an all-layers sliding_window_pattern when the source config sets use_sliding_window. - llama-hparams.cpp (LLM_ARCH_DFLASH_DRAFT): read sliding_window + pattern into n_swa / swa_layers. - build_dflash.cpp + llama-dflash.cpp: the SWA mask path had never run; an all-SWA draft turned the full kq_mask into a dead graph node the scheduler never backs with a buffer, then the input-set wrote it unconditionally (GGML_ASSERT buf!=NULL). Create + set each mask only when a layer uses it; derive mask dims from whichever mask is live. Dense/mixed drafts are byte-identical. Validated on gemma-4-26B-A4B at long context (cross_ctx 8176 > window 2048): no crash, no short-context regression, SWA-on recovers long-context draft acceptance. * DFlash: derive draft SWA pattern from layer_types The converter emitted an all-layers SWA pattern ([True]*n_layers). The z-lab DFlash drafts are sliding-window on every layer except a final full-attention (global) layer, so this ran that global layer as sliding-window and clipped its long-context view. Read layer_types and emit the matching per-layer pattern (sliding_attention -> True), falling back to all-SWA only when layer_types is absent. --------- Co-authored-by: Joel Farthing <262452229+joelfarthing@users.noreply.github.com> --- convert_hf_to_gguf.py | 18 ++++++++++++++++++ src/graphs/build_dflash.cpp | 28 +++++++++++++++++++++++----- src/llama-dflash.cpp | 24 ++++++++++++++++++------ src/llama-hparams.cpp | 5 +++++ 4 files changed, 64 insertions(+), 11 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index cfe5bcd5..820c9607 100644 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2477,6 +2477,24 @@ class DFlashDraftModel(Qwen3Model): self.gguf_writer.add_uint32(f"{arch}.dflash.n_target_features", n_target_features) + # DFlash drafts may be trained with sliding-window attention (for long-context). When the + # source config enables it, emit the window size + the per-layer SWA pattern so the runtime + # activates the kq_mask_swa path. These drafts are typically all sliding-window except a + # final full-attention (global) layer, so honor layer_types when present; fall back to + # all-SWA only when it is absent. Absent/false use_sliding_window => dense draft (unchanged). + use_sliding_window = self.hparams.get("use_sliding_window") + sliding_window = self.hparams.get("sliding_window") + if use_sliding_window and sliding_window: + n_swa_layers = int(self.hparams.get("num_hidden_layers", self.block_count)) + layer_types = self.hparams.get("layer_types") + if layer_types: + swa_pattern = [str(t) == "sliding_attention" for t in layer_types] + else: + swa_pattern = [True] * n_swa_layers + self.gguf_writer.add_sliding_window(int(sliding_window)) + self.gguf_writer.add_sliding_window_pattern(swa_pattern) + logger.info("DFlashDraftModel: sliding_window=%d, SWA pattern=%s", int(sliding_window), swa_pattern) + logger.info( "DFlashDraftModel metadata: block_size=%s mask_token_id=%s target_layer_ids=%s n_target_features=%s", block_size, diff --git a/src/graphs/build_dflash.cpp b/src/graphs/build_dflash.cpp index 05002027..ed867e10 100644 --- a/src/graphs/build_dflash.cpp +++ b/src/graphs/build_dflash.cpp @@ -171,16 +171,34 @@ ggml_cgraph * llm_build_context::build_dflash() { }(); const ggml_type mask_type = flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32; - lctx.dflash.inputs.kq_mask = ggml_new_tensor_2d(ctx0, mask_type, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - lctx.dflash.kv.kq_mask_tensor = lctx.dflash.inputs.kq_mask; - ggml_set_input(lctx.dflash.inputs.kq_mask); - cb(lctx.dflash.inputs.kq_mask, "dflash_kq_mask", -1); + // The full (non-SWA) mask is only consumed by non-SWA layers. For an all-SWA draft every layer + // uses kq_mask_swa, leaving the full mask a dead graph node that the scheduler never backs with a + // buffer (and the unconditional input-set then asserts buf!=NULL). So create each mask only when + // some layer uses it: full mask iff any non-SWA layer; swa mask iff needs_swa_mask. + const bool needs_full_mask = !needs_swa_mask || [&]() { + for (int il = 0; il < n_layer; ++il) { + if (!hparams.swa_layers[il]) { + return true; + } + } + return false; + }(); + + lctx.dflash.inputs.kq_mask = nullptr; + lctx.dflash.kv.kq_mask_tensor = nullptr; + ggml_tensor * dflash_kq_mask_full = nullptr; + if (needs_full_mask) { + lctx.dflash.inputs.kq_mask = ggml_new_tensor_2d(ctx0, mask_type, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + lctx.dflash.kv.kq_mask_tensor = lctx.dflash.inputs.kq_mask; + ggml_set_input(lctx.dflash.inputs.kq_mask); + cb(lctx.dflash.inputs.kq_mask, "dflash_kq_mask", -1); + dflash_kq_mask_full = lctx.dflash.inputs.kq_mask; + } lctx.dflash.kv.draft_tail_rows_tensor = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); ggml_set_input(lctx.dflash.kv.draft_tail_rows_tensor); cb(lctx.dflash.kv.draft_tail_rows_tensor, "dflash_draft_tail_rows", -1); - ggml_tensor * dflash_kq_mask_full = lctx.dflash.inputs.kq_mask; ggml_tensor * dflash_kq_mask_swa = nullptr; lctx.dflash.inputs.kq_mask_swa = nullptr; lctx.dflash.kv.kq_mask_swa_tensor = nullptr; diff --git a/src/llama-dflash.cpp b/src/llama-dflash.cpp index 9ccddeec..f419aab0 100644 --- a/src/llama-dflash.cpp +++ b/src/llama-dflash.cpp @@ -327,7 +327,10 @@ bool llama_prepare_dflash_graph_inputs( ggml_tensor * kq_mask = lctx.dflash.kv.kq_mask_tensor; ggml_tensor * kq_mask_swa = lctx.dflash.kv.kq_mask_swa_tensor; - if (kq_mask == nullptr) { + // An all-SWA draft has no full mask; an all-full draft has no SWA mask. Both masks share the + // same dimensions, so use whichever one is live to derive shape. + ggml_tensor * mask_dims = kq_mask != nullptr ? kq_mask : kq_mask_swa; + if (mask_dims == nullptr) { LLAMA_LOG_ERROR("%s: DFlash graph inputs are not initialized\n", __func__); return false; } @@ -351,8 +354,8 @@ bool llama_prepare_dflash_graph_inputs( const int32_t append_rows_available = lctx.dflash.target.append_features_n_rows; const int32_t width = (int32_t) lctx.model.hparams.dflash_n_target_features; const int32_t graph_cross_ctx = (int32_t) lctx.dflash.kv.cache_pos.size(); - const int32_t n_mask_tokens = (int32_t) kq_mask->ne[1]; - const int32_t n_kv_total = (int32_t) kq_mask->ne[0]; + const int32_t n_mask_tokens = (int32_t) mask_dims->ne[1]; + const int32_t n_kv_total = (int32_t) mask_dims->ne[0]; ggml_tensor * draft_tail_rows = lctx.dflash.kv.draft_tail_rows_tensor; if (graph_cross_ctx != cross_ctx) { @@ -559,7 +562,10 @@ bool llama_prepare_dflash_graph_inputs( ggml_backend_tensor_set(draft_tail_rows, draft_tail_rows_data.data(), 0, ggml_nbytes(draft_tail_rows)); const size_t mask_elems = (size_t) n_kv_total * (size_t) n_mask_tokens; - if (kq_mask->type == GGML_TYPE_F16) { + if (kq_mask == nullptr) { + // all-SWA draft: the full mask was not created (no non-SWA layer consumes it); only the + // SWA mask below is populated. + } else if (kq_mask->type == GGML_TYPE_F16) { const ggml_fp16_t h_inf = ggml_fp32_to_fp16(-INFINITY); const ggml_fp16_t h_zero = ggml_fp32_to_fp16(0.0f); std::vector mask_f16(mask_elems, h_inf); @@ -613,7 +619,10 @@ bool llama_prepare_dflash_graph_inputs( for (int32_t k = cross_ctx; k < cross_ctx + (int32_t) n_tokens; ++k) { const int32_t block_k = k - cross_ctx; - if (block_k <= (int32_t) j) { + // intra-block draft tokens are contiguous from draft_pos_base, so the + // SWA distance is (j - block_k); apply the same window bound as the + // cross-context section above (causal AND within n_swa). + if (block_k <= (int32_t) j && ((int32_t) j - block_k) < swa_window) { row[k] = h_zero; } } @@ -637,7 +646,10 @@ bool llama_prepare_dflash_graph_inputs( for (int32_t k = cross_ctx; k < cross_ctx + (int32_t) n_tokens; ++k) { const int32_t block_k = k - cross_ctx; - if (block_k <= (int32_t) j) { + // intra-block draft tokens are contiguous from draft_pos_base, so the + // SWA distance is (j - block_k); apply the same window bound as the + // cross-context section above (causal AND within n_swa). + if (block_k <= (int32_t) j && ((int32_t) j - block_k) < swa_window) { row[k] = 0.0f; } } diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index fbca1a4b..bfb93283 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -903,6 +903,11 @@ void llm_load_hparams( ml.get_key(LLM_KV_DFLASH_MASK_TOKEN_ID, hparams.dflash_mask_token_id, false); ml.get_key(LLM_KV_DFLASH_N_TARGET_FEATURES, hparams.dflash_n_target_features, false); load_dflash_target_layer_ids(ml, LLM_KV(model.arch)(LLM_KV_DFLASH_TARGET_LAYER_IDS), hparams, false); + // DFlash drafts may be trained with sliding-window attention (for long-context). + // Read the window + per-layer pattern so the SWA mask path activates; absent keys + // leave n_swa=0 / swa_layers all-zero (dense behavior, unchanged). + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer, false); validate_dflash_hparams(hparams, model.arch); hparams.n_layer_kv_from_start = hparams.n_layer;