From 4f220159b800828718dd778ca0d3600343588504 Mon Sep 17 00:00:00 2001 From: gapeleon <191471103+gapeleon@users.noreply.github.com> Date: Thu, 18 Jun 2026 00:52:45 +1000 Subject: [PATCH] Fix (Gemma-4 Vision): Correct KQ mask fill for causal models in non-causal flash-attn mode (#1985) When llama_set_causal_attn(false) is called on a causal model (e.g. Gemma-4 during vision image decode), llama_set_inputs took the non-causal else-branch (designed for pure embedding models). That path wrote the F16 mask with stride n_tokens instead of n_kv, and iterated batch indices rather than KV cache cells. The result was that every image query row beyond the first was written at the wrong offset, leaving stale -inf values from previous decodes visible to the GPU kernel. Any conversation that had built up prior KV mask data would produce all-inf attention scores for most image tokens, collapsing softmax to NaN and aborting at sampling. Resolves #1984 --- src/llama.cpp | 55 ++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 50 insertions(+), 5 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 4836f998..0a20563f 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4335,12 +4335,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { auto tim1 = ggml_time_us(); #endif // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. + // mask_kv_self/n_kv are needed in both the causal branch and the non-causal flash-attn branch below. + const llama_kv_cache & mask_kv_self = + ((lctx.model.arch == LLM_ARCH_GEMMA4_MTP || lctx.model.arch == LLM_ARCH_GEMMA4_ASSISTANT) && lctx.mtp_target_ctx != nullptr) + ? lctx.mtp_target_ctx->kv_self + : kv_self; + const int64_t n_kv = mask_kv_self.n; if (cparams.causal_attn && !lctx.is_encoding) { - const llama_kv_cache & mask_kv_self = - ((lctx.model.arch == LLM_ARCH_GEMMA4_MTP || lctx.model.arch == LLM_ARCH_GEMMA4_ASSISTANT) && lctx.mtp_target_ctx != nullptr) - ? lctx.mtp_target_ctx->kv_self - : kv_self; - const int64_t n_kv = mask_kv_self.n; const int64_t n_tokens = batch.n_tokens; @@ -4582,6 +4583,48 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } + // For causal models running in non-causal mode (e.g., Gemma-4 image decode), + // the flash-attn mask is allocated as [n_kv, n_tokens_pad] and must be filled + // using KV cache cell metadata — not batch-token indices — because image tokens + // occupy cells starting at n_past, not at cell 0. + if (cparams.flash_attn && hparams.causal_attn && !lctx.is_encoding) { + const ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY); + const ggml_half h_zero = ggml_fp32_to_fp16(0.f); + for (int j = 0; j < n_tokens; ++j) { + const llama_seq_id seq_id = batch.seq_id[j][0]; + const llama_pos pos = batch.pos[j]; + for (int i = 0; i < n_kv; ++i) { + const bool valid = mask_kv_self.cells[i].has_seq_id(seq_id); + if (data_f16) { + data_f16[j*n_kv + i] = valid ? h_zero : h_inf; + } + if (data_swa_f16) { + ggml_half h = valid ? h_zero : h_inf; + if (h == h_zero) { + if (hparams.n_attn_chunk) { + const llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk; + if (mask_kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) { + h = h_inf; + } + } else if (pos - mask_kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { + h = h_inf; + } + } + data_swa_f16[j*n_kv + i] = h; + } + } + } + const int64_t n_tokens_padded = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); + if (n_tokens_padded > n_tokens) { + if (data_f16) { + std::fill(data_f16 + n_tokens*n_kv, data_f16 + n_tokens_padded*n_kv, h_inf); + } + if (data_swa_f16) { + std::fill(data_swa_f16 + n_tokens*n_kv, data_swa_f16 + n_tokens_padded*n_kv, h_inf); + } + } + } else { + for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_seq_id seq_id = batch.seq_id[j][0]; @@ -4639,6 +4682,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } } + + } // end else (non-flash or non-causal-model path) #if IK_PRINT_TIMING == 2 auto tim2 = ggml_time_us(); printf("set_inputs(mask2): %d us\n", int(tim2-tim1));