mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
DFlash: enable sliding-window attention for draft models (#2021)
* DFlash: bound intra-block draft tokens to the SWA window The SWA mask builder applied the sliding-window distance check only to the cross-context section; the intra-block draft-token loop masked causal-only, so a draft token could attend to earlier block tokens beyond n_swa. Apply the same window bound ((j - block_k) < swa_window) in both the F16 and F32 paths so it matches the cross-context section. Behavior-neutral for dense models: the SWA mask tensor is only allocated when the model has SWA layers (build_dflash.cpp needs_swa_mask gate), so for dense targets the changed block is unreachable. * DFlash: enable sliding-window attention for draft models DFlash drafts can be trained with sliding-window attention for long context, but the runtime ignored it: the draft loader never read the window keys and the converter never emitted them, so SWA-trained drafts always ran full-attention. Enable it end to end and fix the dormant SWA graph path it exposes: - convert_hf_to_gguf.py (DFlashDraftModel): emit attention.sliding_window + an all-layers sliding_window_pattern when the source config sets use_sliding_window. - llama-hparams.cpp (LLM_ARCH_DFLASH_DRAFT): read sliding_window + pattern into n_swa / swa_layers. - build_dflash.cpp + llama-dflash.cpp: the SWA mask path had never run; an all-SWA draft turned the full kq_mask into a dead graph node the scheduler never backs with a buffer, then the input-set wrote it unconditionally (GGML_ASSERT buf!=NULL). Create + set each mask only when a layer uses it; derive mask dims from whichever mask is live. Dense/mixed drafts are byte-identical. Validated on gemma-4-26B-A4B at long context (cross_ctx 8176 > window 2048): no crash, no short-context regression, SWA-on recovers long-context draft acceptance. * DFlash: derive draft SWA pattern from layer_types The converter emitted an all-layers SWA pattern ([True]*n_layers). The z-lab DFlash drafts are sliding-window on every layer except a final full-attention (global) layer, so this ran that global layer as sliding-window and clipped its long-context view. Read layer_types and emit the matching per-layer pattern (sliding_attention -> True), falling back to all-SWA only when layer_types is absent. --------- Co-authored-by: Joel Farthing <262452229+joelfarthing@users.noreply.github.com>
This commit is contained in:
parent
4553cd0059
commit
bdf5c081dc
@ -2477,6 +2477,24 @@ class DFlashDraftModel(Qwen3Model):
|
||||
|
||||
self.gguf_writer.add_uint32(f"{arch}.dflash.n_target_features", n_target_features)
|
||||
|
||||
# DFlash drafts may be trained with sliding-window attention (for long-context). When the
|
||||
# source config enables it, emit the window size + the per-layer SWA pattern so the runtime
|
||||
# activates the kq_mask_swa path. These drafts are typically all sliding-window except a
|
||||
# final full-attention (global) layer, so honor layer_types when present; fall back to
|
||||
# all-SWA only when it is absent. Absent/false use_sliding_window => dense draft (unchanged).
|
||||
use_sliding_window = self.hparams.get("use_sliding_window")
|
||||
sliding_window = self.hparams.get("sliding_window")
|
||||
if use_sliding_window and sliding_window:
|
||||
n_swa_layers = int(self.hparams.get("num_hidden_layers", self.block_count))
|
||||
layer_types = self.hparams.get("layer_types")
|
||||
if layer_types:
|
||||
swa_pattern = [str(t) == "sliding_attention" for t in layer_types]
|
||||
else:
|
||||
swa_pattern = [True] * n_swa_layers
|
||||
self.gguf_writer.add_sliding_window(int(sliding_window))
|
||||
self.gguf_writer.add_sliding_window_pattern(swa_pattern)
|
||||
logger.info("DFlashDraftModel: sliding_window=%d, SWA pattern=%s", int(sliding_window), swa_pattern)
|
||||
|
||||
logger.info(
|
||||
"DFlashDraftModel metadata: block_size=%s mask_token_id=%s target_layer_ids=%s n_target_features=%s",
|
||||
block_size,
|
||||
|
||||
@ -171,16 +171,34 @@ ggml_cgraph * llm_build_context::build_dflash() {
|
||||
}();
|
||||
const ggml_type mask_type = flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
|
||||
lctx.dflash.inputs.kq_mask = ggml_new_tensor_2d(ctx0, mask_type, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
lctx.dflash.kv.kq_mask_tensor = lctx.dflash.inputs.kq_mask;
|
||||
ggml_set_input(lctx.dflash.inputs.kq_mask);
|
||||
cb(lctx.dflash.inputs.kq_mask, "dflash_kq_mask", -1);
|
||||
// The full (non-SWA) mask is only consumed by non-SWA layers. For an all-SWA draft every layer
|
||||
// uses kq_mask_swa, leaving the full mask a dead graph node that the scheduler never backs with a
|
||||
// buffer (and the unconditional input-set then asserts buf!=NULL). So create each mask only when
|
||||
// some layer uses it: full mask iff any non-SWA layer; swa mask iff needs_swa_mask.
|
||||
const bool needs_full_mask = !needs_swa_mask || [&]() {
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
if (!hparams.swa_layers[il]) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}();
|
||||
|
||||
lctx.dflash.inputs.kq_mask = nullptr;
|
||||
lctx.dflash.kv.kq_mask_tensor = nullptr;
|
||||
ggml_tensor * dflash_kq_mask_full = nullptr;
|
||||
if (needs_full_mask) {
|
||||
lctx.dflash.inputs.kq_mask = ggml_new_tensor_2d(ctx0, mask_type, n_kv_total, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
lctx.dflash.kv.kq_mask_tensor = lctx.dflash.inputs.kq_mask;
|
||||
ggml_set_input(lctx.dflash.inputs.kq_mask);
|
||||
cb(lctx.dflash.inputs.kq_mask, "dflash_kq_mask", -1);
|
||||
dflash_kq_mask_full = lctx.dflash.inputs.kq_mask;
|
||||
}
|
||||
|
||||
lctx.dflash.kv.draft_tail_rows_tensor = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||
ggml_set_input(lctx.dflash.kv.draft_tail_rows_tensor);
|
||||
cb(lctx.dflash.kv.draft_tail_rows_tensor, "dflash_draft_tail_rows", -1);
|
||||
|
||||
ggml_tensor * dflash_kq_mask_full = lctx.dflash.inputs.kq_mask;
|
||||
ggml_tensor * dflash_kq_mask_swa = nullptr;
|
||||
lctx.dflash.inputs.kq_mask_swa = nullptr;
|
||||
lctx.dflash.kv.kq_mask_swa_tensor = nullptr;
|
||||
|
||||
@ -327,7 +327,10 @@ bool llama_prepare_dflash_graph_inputs(
|
||||
ggml_tensor * kq_mask = lctx.dflash.kv.kq_mask_tensor;
|
||||
ggml_tensor * kq_mask_swa = lctx.dflash.kv.kq_mask_swa_tensor;
|
||||
|
||||
if (kq_mask == nullptr) {
|
||||
// An all-SWA draft has no full mask; an all-full draft has no SWA mask. Both masks share the
|
||||
// same dimensions, so use whichever one is live to derive shape.
|
||||
ggml_tensor * mask_dims = kq_mask != nullptr ? kq_mask : kq_mask_swa;
|
||||
if (mask_dims == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: DFlash graph inputs are not initialized\n", __func__);
|
||||
return false;
|
||||
}
|
||||
@ -351,8 +354,8 @@ bool llama_prepare_dflash_graph_inputs(
|
||||
const int32_t append_rows_available = lctx.dflash.target.append_features_n_rows;
|
||||
const int32_t width = (int32_t) lctx.model.hparams.dflash_n_target_features;
|
||||
const int32_t graph_cross_ctx = (int32_t) lctx.dflash.kv.cache_pos.size();
|
||||
const int32_t n_mask_tokens = (int32_t) kq_mask->ne[1];
|
||||
const int32_t n_kv_total = (int32_t) kq_mask->ne[0];
|
||||
const int32_t n_mask_tokens = (int32_t) mask_dims->ne[1];
|
||||
const int32_t n_kv_total = (int32_t) mask_dims->ne[0];
|
||||
ggml_tensor * draft_tail_rows = lctx.dflash.kv.draft_tail_rows_tensor;
|
||||
|
||||
if (graph_cross_ctx != cross_ctx) {
|
||||
@ -559,7 +562,10 @@ bool llama_prepare_dflash_graph_inputs(
|
||||
ggml_backend_tensor_set(draft_tail_rows, draft_tail_rows_data.data(), 0, ggml_nbytes(draft_tail_rows));
|
||||
|
||||
const size_t mask_elems = (size_t) n_kv_total * (size_t) n_mask_tokens;
|
||||
if (kq_mask->type == GGML_TYPE_F16) {
|
||||
if (kq_mask == nullptr) {
|
||||
// all-SWA draft: the full mask was not created (no non-SWA layer consumes it); only the
|
||||
// SWA mask below is populated.
|
||||
} else if (kq_mask->type == GGML_TYPE_F16) {
|
||||
const ggml_fp16_t h_inf = ggml_fp32_to_fp16(-INFINITY);
|
||||
const ggml_fp16_t h_zero = ggml_fp32_to_fp16(0.0f);
|
||||
std::vector<ggml_fp16_t> mask_f16(mask_elems, h_inf);
|
||||
@ -613,7 +619,10 @@ bool llama_prepare_dflash_graph_inputs(
|
||||
|
||||
for (int32_t k = cross_ctx; k < cross_ctx + (int32_t) n_tokens; ++k) {
|
||||
const int32_t block_k = k - cross_ctx;
|
||||
if (block_k <= (int32_t) j) {
|
||||
// intra-block draft tokens are contiguous from draft_pos_base, so the
|
||||
// SWA distance is (j - block_k); apply the same window bound as the
|
||||
// cross-context section above (causal AND within n_swa).
|
||||
if (block_k <= (int32_t) j && ((int32_t) j - block_k) < swa_window) {
|
||||
row[k] = h_zero;
|
||||
}
|
||||
}
|
||||
@ -637,7 +646,10 @@ bool llama_prepare_dflash_graph_inputs(
|
||||
|
||||
for (int32_t k = cross_ctx; k < cross_ctx + (int32_t) n_tokens; ++k) {
|
||||
const int32_t block_k = k - cross_ctx;
|
||||
if (block_k <= (int32_t) j) {
|
||||
// intra-block draft tokens are contiguous from draft_pos_base, so the
|
||||
// SWA distance is (j - block_k); apply the same window bound as the
|
||||
// cross-context section above (causal AND within n_swa).
|
||||
if (block_k <= (int32_t) j && ((int32_t) j - block_k) < swa_window) {
|
||||
row[k] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
@ -903,6 +903,11 @@ void llm_load_hparams(
|
||||
ml.get_key(LLM_KV_DFLASH_MASK_TOKEN_ID, hparams.dflash_mask_token_id, false);
|
||||
ml.get_key(LLM_KV_DFLASH_N_TARGET_FEATURES, hparams.dflash_n_target_features, false);
|
||||
load_dflash_target_layer_ids(ml, LLM_KV(model.arch)(LLM_KV_DFLASH_TARGET_LAYER_IDS), hparams, false);
|
||||
// DFlash drafts may be trained with sliding-window attention (for long-context).
|
||||
// Read the window + per-layer pattern so the SWA mask path activates; absent keys
|
||||
// leave n_swa=0 / swa_layers all-zero (dense behavior, unchanged).
|
||||
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
|
||||
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer, false);
|
||||
validate_dflash_hparams(hparams, model.arch);
|
||||
|
||||
hparams.n_layer_kv_from_start = hparams.n_layer;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user