mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Fix (Gemma-4 Vision): Correct KQ mask fill for causal models in non-causal flash-attn mode (#1985)
When llama_set_causal_attn(false) is called on a causal model (e.g. Gemma-4 during vision image decode), llama_set_inputs took the non-causal else-branch (designed for pure embedding models). That path wrote the F16 mask with stride n_tokens instead of n_kv, and iterated batch indices rather than KV cache cells. The result was that every image query row beyond the first was written at the wrong offset, leaving stale -inf values from previous decodes visible to the GPU kernel. Any conversation that had built up prior KV mask data would produce all-inf attention scores for most image tokens, collapsing softmax to NaN and aborting at sampling. Resolves #1984
This commit is contained in:
parent
064d23a6f8
commit
4f220159b8
@ -4335,12 +4335,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
auto tim1 = ggml_time_us();
|
||||
#endif
|
||||
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
|
||||
// mask_kv_self/n_kv are needed in both the causal branch and the non-causal flash-attn branch below.
|
||||
const llama_kv_cache & mask_kv_self =
|
||||
((lctx.model.arch == LLM_ARCH_GEMMA4_MTP || lctx.model.arch == LLM_ARCH_GEMMA4_ASSISTANT) && lctx.mtp_target_ctx != nullptr)
|
||||
? lctx.mtp_target_ctx->kv_self
|
||||
: kv_self;
|
||||
const int64_t n_kv = mask_kv_self.n;
|
||||
if (cparams.causal_attn && !lctx.is_encoding) {
|
||||
const llama_kv_cache & mask_kv_self =
|
||||
((lctx.model.arch == LLM_ARCH_GEMMA4_MTP || lctx.model.arch == LLM_ARCH_GEMMA4_ASSISTANT) && lctx.mtp_target_ctx != nullptr)
|
||||
? lctx.mtp_target_ctx->kv_self
|
||||
: kv_self;
|
||||
const int64_t n_kv = mask_kv_self.n;
|
||||
const int64_t n_tokens = batch.n_tokens;
|
||||
|
||||
|
||||
@ -4582,6 +4583,48 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
}
|
||||
}
|
||||
|
||||
// For causal models running in non-causal mode (e.g., Gemma-4 image decode),
|
||||
// the flash-attn mask is allocated as [n_kv, n_tokens_pad] and must be filled
|
||||
// using KV cache cell metadata — not batch-token indices — because image tokens
|
||||
// occupy cells starting at n_past, not at cell 0.
|
||||
if (cparams.flash_attn && hparams.causal_attn && !lctx.is_encoding) {
|
||||
const ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY);
|
||||
const ggml_half h_zero = ggml_fp32_to_fp16(0.f);
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
const llama_seq_id seq_id = batch.seq_id[j][0];
|
||||
const llama_pos pos = batch.pos[j];
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
const bool valid = mask_kv_self.cells[i].has_seq_id(seq_id);
|
||||
if (data_f16) {
|
||||
data_f16[j*n_kv + i] = valid ? h_zero : h_inf;
|
||||
}
|
||||
if (data_swa_f16) {
|
||||
ggml_half h = valid ? h_zero : h_inf;
|
||||
if (h == h_zero) {
|
||||
if (hparams.n_attn_chunk) {
|
||||
const llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
|
||||
if (mask_kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
|
||||
h = h_inf;
|
||||
}
|
||||
} else if (pos - mask_kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
|
||||
h = h_inf;
|
||||
}
|
||||
}
|
||||
data_swa_f16[j*n_kv + i] = h;
|
||||
}
|
||||
}
|
||||
}
|
||||
const int64_t n_tokens_padded = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD);
|
||||
if (n_tokens_padded > n_tokens) {
|
||||
if (data_f16) {
|
||||
std::fill(data_f16 + n_tokens*n_kv, data_f16 + n_tokens_padded*n_kv, h_inf);
|
||||
}
|
||||
if (data_swa_f16) {
|
||||
std::fill(data_swa_f16 + n_tokens*n_kv, data_swa_f16 + n_tokens_padded*n_kv, h_inf);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
const llama_seq_id seq_id = batch.seq_id[j][0];
|
||||
@ -4639,6 +4682,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // end else (non-flash or non-causal-model path)
|
||||
#if IK_PRINT_TIMING == 2
|
||||
auto tim2 = ggml_time_us();
|
||||
printf("set_inputs(mask2): %d us\n", int(tim2-tim1));
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user