Fix (Gemma-4 Vision): Correct KQ mask fill for causal models in non-causal flash-attn mode (#1985)

When llama_set_causal_attn(false) is called on a causal model (e.g.
Gemma-4 during vision image decode), llama_set_inputs took the non-causal
else-branch (designed for pure embedding models).

That path wrote the F16 mask with stride n_tokens instead of n_kv, and iterated batch
indices rather than KV cache cells.

The result was that every image query row beyond the first was
written at the wrong offset, leaving stale -inf values from
previous decodes visible to the GPU kernel. Any conversation
that had built up prior KV mask data would produce all-inf attention scores
for most image tokens, collapsing softmax to NaN and aborting at sampling.

Resolves #1984
This commit is contained in:
gapeleon 2026-06-18 00:52:45 +10:00 committed by GitHub
parent 064d23a6f8
commit 4f220159b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4335,12 +4335,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
auto tim1 = ggml_time_us();
#endif
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
if (cparams.causal_attn && !lctx.is_encoding) {
// mask_kv_self/n_kv are needed in both the causal branch and the non-causal flash-attn branch below.
const llama_kv_cache & mask_kv_self =
((lctx.model.arch == LLM_ARCH_GEMMA4_MTP || lctx.model.arch == LLM_ARCH_GEMMA4_ASSISTANT) && lctx.mtp_target_ctx != nullptr)
? lctx.mtp_target_ctx->kv_self
: kv_self;
const int64_t n_kv = mask_kv_self.n;
if (cparams.causal_attn && !lctx.is_encoding) {
const int64_t n_tokens = batch.n_tokens;
@ -4582,6 +4583,48 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
// For causal models running in non-causal mode (e.g., Gemma-4 image decode),
// the flash-attn mask is allocated as [n_kv, n_tokens_pad] and must be filled
// using KV cache cell metadata — not batch-token indices — because image tokens
// occupy cells starting at n_past, not at cell 0.
if (cparams.flash_attn && hparams.causal_attn && !lctx.is_encoding) {
const ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY);
const ggml_half h_zero = ggml_fp32_to_fp16(0.f);
for (int j = 0; j < n_tokens; ++j) {
const llama_seq_id seq_id = batch.seq_id[j][0];
const llama_pos pos = batch.pos[j];
for (int i = 0; i < n_kv; ++i) {
const bool valid = mask_kv_self.cells[i].has_seq_id(seq_id);
if (data_f16) {
data_f16[j*n_kv + i] = valid ? h_zero : h_inf;
}
if (data_swa_f16) {
ggml_half h = valid ? h_zero : h_inf;
if (h == h_zero) {
if (hparams.n_attn_chunk) {
const llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
if (mask_kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
h = h_inf;
}
} else if (pos - mask_kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
h = h_inf;
}
}
data_swa_f16[j*n_kv + i] = h;
}
}
}
const int64_t n_tokens_padded = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD);
if (n_tokens_padded > n_tokens) {
if (data_f16) {
std::fill(data_f16 + n_tokens*n_kv, data_f16 + n_tokens_padded*n_kv, h_inf);
}
if (data_swa_f16) {
std::fill(data_swa_f16 + n_tokens*n_kv, data_swa_f16 + n_tokens_padded*n_kv, h_inf);
}
}
} else {
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_seq_id seq_id = batch.seq_id[j][0];
@ -4639,6 +4682,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
}
} // end else (non-flash or non-causal-model path)
#if IK_PRINT_TIMING == 2
auto tim2 = ggml_time_us();
printf("set_inputs(mask2): %d us\n", int(tim2-tim1));