DFlash: enable sliding-window attention for draft models (#2021)

* DFlash: bound intra-block draft tokens to the SWA window

The SWA mask builder applied the sliding-window distance check only to the
cross-context section; the intra-block draft-token loop masked causal-only,
so a draft token could attend to earlier block tokens beyond n_swa. Apply
the same window bound ((j - block_k) < swa_window) in both the F16 and F32
paths so it matches the cross-context section.

Behavior-neutral for dense models: the SWA mask tensor is only allocated
when the model has SWA layers (build_dflash.cpp needs_swa_mask gate), so
for dense models the changed block is unreachable.

* DFlash: enable sliding-window attention for draft models

DFlash drafts can be trained with sliding-window attention for long context,
but the runtime ignored it: the draft loader never read the window keys and the
converter never emitted them, so SWA-trained drafts always ran with full attention.
Enable SWA end to end and fix the dormant SWA graph path that enabling it exposes:

- convert_hf_to_gguf.py (DFlashDraftModel): emit attention.sliding_window + an
  all-layers sliding_window_pattern when the source config sets use_sliding_window.
- llama-hparams.cpp (LLM_ARCH_DFLASH_DRAFT): read sliding_window + pattern into
  n_swa / swa_layers.
- build_dflash.cpp + llama-dflash.cpp: the SWA mask path had never run. With an
  all-SWA draft the full kq_mask became a dead graph node that the scheduler never
  backs with a buffer, yet the input-set step still wrote to it unconditionally and
  hit GGML_ASSERT(buf != NULL). Create + set each mask only when a layer uses it;
  derive mask dims from whichever mask is live. Dense/mixed drafts are byte-identical.

Validated on gemma-4-26B-A4B at long context (cross_ctx 8176 > window 2048): no
crash, no short-context regression, SWA-on recovers long-context draft acceptance.

* DFlash: derive draft SWA pattern from layer_types

The converter emitted an all-layers SWA pattern ([True]*n_layers). The z-lab
DFlash drafts are sliding-window on every layer except a final full-attention
(global) layer, so this ran that global layer as sliding-window and clipped its
long-context view. Read layer_types and emit the matching per-layer pattern
(sliding_attention -> True), falling back to all-SWA only when layer_types is
absent.

---------

Co-authored-by: Joel Farthing <262452229+joelfarthing@users.noreply.github.com>
This commit is contained in:
Joel Farthing 2026-06-25 02:06:54 -05:00 committed by GitHub
parent 4553cd0059
commit bdf5c081dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 64 additions and 11 deletions

View File

@ -2477,6 +2477,24 @@ class DFlashDraftModel(Qwen3Model):
self.gguf_writer.add_uint32(f"{arch}.dflash.n_target_features", n_target_features)
# DFlash drafts may be trained with sliding-window attention (for long-context). When the
# source config enables it, emit the window size + the per-layer SWA pattern so the runtime
# activates the kq_mask_swa path. These drafts are typically all sliding-window except a
# final full-attention (global) layer, so honor layer_types when present; fall back to
# all-SWA only when it is absent. Absent/false use_sliding_window => dense draft (unchanged).
use_sliding_window = self.hparams.get("use_sliding_window")
sliding_window = self.hparams.get("sliding_window")
if use_sliding_window and sliding_window:
n_swa_layers = int(self.hparams.get("num_hidden_layers", self.block_count))
layer_types = self.hparams.get("layer_types")
if layer_types:
swa_pattern = [str(t) == "sliding_attention" for t in layer_types]
else:
swa_pattern = [True] * n_swa_layers
self.gguf_writer.add_sliding_window(int(sliding_window))
self.gguf_writer.add_sliding_window_pattern(swa_pattern)
logger.info("DFlashDraftModel: sliding_window=%d, SWA pattern=%s", int(sliding_window), swa_pattern)
logger.info(
"DFlashDraftModel metadata: block_size=%s mask_token_id=%s target_layer_ids=%s n_target_features=%s",
block_size,

View File

@ -171,16 +171,34 @@ ggml_cgraph * llm_build_context::build_dflash() {
}();
const ggml_type mask_type = flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
lctx.dflash.inputs.kq_mask = ggml_new_tensor_2d(ctx0, mask_type, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
lctx.dflash.kv.kq_mask_tensor = lctx.dflash.inputs.kq_mask;
ggml_set_input(lctx.dflash.inputs.kq_mask);
cb(lctx.dflash.inputs.kq_mask, "dflash_kq_mask", -1);
// The full (non-SWA) mask is only consumed by non-SWA layers. For an all-SWA draft every layer
// uses kq_mask_swa, leaving the full mask a dead graph node that the scheduler never backs with a
// buffer (and the unconditional input-set then asserts buf!=NULL). So create each mask only when
// some layer uses it: full mask iff any non-SWA layer; swa mask iff needs_swa_mask.
const bool needs_full_mask = !needs_swa_mask || [&]() {
for (int il = 0; il < n_layer; ++il) {
if (!hparams.swa_layers[il]) {
return true;
}
}
return false;
}();
lctx.dflash.inputs.kq_mask = nullptr;
lctx.dflash.kv.kq_mask_tensor = nullptr;
ggml_tensor * dflash_kq_mask_full = nullptr;
if (needs_full_mask) {
lctx.dflash.inputs.kq_mask = ggml_new_tensor_2d(ctx0, mask_type, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
lctx.dflash.kv.kq_mask_tensor = lctx.dflash.inputs.kq_mask;
ggml_set_input(lctx.dflash.inputs.kq_mask);
cb(lctx.dflash.inputs.kq_mask, "dflash_kq_mask", -1);
dflash_kq_mask_full = lctx.dflash.inputs.kq_mask;
}
lctx.dflash.kv.draft_tail_rows_tensor = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
ggml_set_input(lctx.dflash.kv.draft_tail_rows_tensor);
cb(lctx.dflash.kv.draft_tail_rows_tensor, "dflash_draft_tail_rows", -1);
ggml_tensor * dflash_kq_mask_full = lctx.dflash.inputs.kq_mask;
ggml_tensor * dflash_kq_mask_swa = nullptr;
lctx.dflash.inputs.kq_mask_swa = nullptr;
lctx.dflash.kv.kq_mask_swa_tensor = nullptr;

View File

@ -327,7 +327,10 @@ bool llama_prepare_dflash_graph_inputs(
ggml_tensor * kq_mask = lctx.dflash.kv.kq_mask_tensor;
ggml_tensor * kq_mask_swa = lctx.dflash.kv.kq_mask_swa_tensor;
if (kq_mask == nullptr) {
// An all-SWA draft has no full mask; an all-full draft has no SWA mask. Both masks share the
// same dimensions, so use whichever one is live to derive shape.
ggml_tensor * mask_dims = kq_mask != nullptr ? kq_mask : kq_mask_swa;
if (mask_dims == nullptr) {
LLAMA_LOG_ERROR("%s: DFlash graph inputs are not initialized\n", __func__);
return false;
}
@ -351,8 +354,8 @@ bool llama_prepare_dflash_graph_inputs(
const int32_t append_rows_available = lctx.dflash.target.append_features_n_rows;
const int32_t width = (int32_t) lctx.model.hparams.dflash_n_target_features;
const int32_t graph_cross_ctx = (int32_t) lctx.dflash.kv.cache_pos.size();
const int32_t n_mask_tokens = (int32_t) kq_mask->ne[1];
const int32_t n_kv_total = (int32_t) kq_mask->ne[0];
const int32_t n_mask_tokens = (int32_t) mask_dims->ne[1];
const int32_t n_kv_total = (int32_t) mask_dims->ne[0];
ggml_tensor * draft_tail_rows = lctx.dflash.kv.draft_tail_rows_tensor;
if (graph_cross_ctx != cross_ctx) {
@ -559,7 +562,10 @@ bool llama_prepare_dflash_graph_inputs(
ggml_backend_tensor_set(draft_tail_rows, draft_tail_rows_data.data(), 0, ggml_nbytes(draft_tail_rows));
const size_t mask_elems = (size_t) n_kv_total * (size_t) n_mask_tokens;
if (kq_mask->type == GGML_TYPE_F16) {
if (kq_mask == nullptr) {
// all-SWA draft: the full mask was not created (no non-SWA layer consumes it); only the
// SWA mask below is populated.
} else if (kq_mask->type == GGML_TYPE_F16) {
const ggml_fp16_t h_inf = ggml_fp32_to_fp16(-INFINITY);
const ggml_fp16_t h_zero = ggml_fp32_to_fp16(0.0f);
std::vector<ggml_fp16_t> mask_f16(mask_elems, h_inf);
@ -613,7 +619,10 @@ bool llama_prepare_dflash_graph_inputs(
for (int32_t k = cross_ctx; k < cross_ctx + (int32_t) n_tokens; ++k) {
const int32_t block_k = k - cross_ctx;
if (block_k <= (int32_t) j) {
// intra-block draft tokens are contiguous from draft_pos_base, so the
// SWA distance is (j - block_k); apply the same window bound as the
// cross-context section above (causal AND within n_swa).
if (block_k <= (int32_t) j && ((int32_t) j - block_k) < swa_window) {
row[k] = h_zero;
}
}
@ -637,7 +646,10 @@ bool llama_prepare_dflash_graph_inputs(
for (int32_t k = cross_ctx; k < cross_ctx + (int32_t) n_tokens; ++k) {
const int32_t block_k = k - cross_ctx;
if (block_k <= (int32_t) j) {
// intra-block draft tokens are contiguous from draft_pos_base, so the
// SWA distance is (j - block_k); apply the same window bound as the
// cross-context section above (causal AND within n_swa).
if (block_k <= (int32_t) j && ((int32_t) j - block_k) < swa_window) {
row[k] = 0.0f;
}
}

View File

@ -903,6 +903,11 @@ void llm_load_hparams(
ml.get_key(LLM_KV_DFLASH_MASK_TOKEN_ID, hparams.dflash_mask_token_id, false);
ml.get_key(LLM_KV_DFLASH_N_TARGET_FEATURES, hparams.dflash_n_target_features, false);
load_dflash_target_layer_ids(ml, LLM_KV(model.arch)(LLM_KV_DFLASH_TARGET_LAYER_IDS), hparams, false);
// DFlash drafts may be trained with sliding-window attention (for long-context).
// Read the window + per-layer pattern so the SWA mask path activates; absent keys
// leave n_swa=0 / swa_layers all-zero (dense behavior, unchanged).
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer, false);
validate_dflash_hparams(hparams, model.arch);
hparams.n_layer_kv_from_start = hparams.n_layer;