mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Fix (Gemma-4 Vision): Correct KQ mask fill for causal models in non-causal flash-attn mode (#1985)
When llama_set_causal_attn(false) is called on a causal model (e.g. Gemma-4 during vision image decode), llama_set_inputs took the non-causal else-branch (designed for pure embedding models). That path wrote the F16 mask with stride n_tokens instead of n_kv, and iterated batch indices rather than KV cache cells. The result was that every image query row beyond the first was written at the wrong offset, leaving stale -inf values from previous decodes visible to the GPU kernel. Any conversation that had built up prior KV mask data would produce all -inf attention scores for most image tokens, collapsing softmax to NaN and aborting at sampling. Resolves #1984
This commit is contained in:
parent
064d23a6f8
commit
4f220159b8
@ -4335,12 +4335,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|||||||
auto tim1 = ggml_time_us();
|
auto tim1 = ggml_time_us();
|
||||||
#endif
|
#endif
|
||||||
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
|
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
|
||||||
|
// mask_kv_self/n_kv are needed in both the causal branch and the non-causal flash-attn branch below.
|
||||||
|
const llama_kv_cache & mask_kv_self =
|
||||||
|
((lctx.model.arch == LLM_ARCH_GEMMA4_MTP || lctx.model.arch == LLM_ARCH_GEMMA4_ASSISTANT) && lctx.mtp_target_ctx != nullptr)
|
||||||
|
? lctx.mtp_target_ctx->kv_self
|
||||||
|
: kv_self;
|
||||||
|
const int64_t n_kv = mask_kv_self.n;
|
||||||
if (cparams.causal_attn && !lctx.is_encoding) {
|
if (cparams.causal_attn && !lctx.is_encoding) {
|
||||||
const llama_kv_cache & mask_kv_self =
|
|
||||||
((lctx.model.arch == LLM_ARCH_GEMMA4_MTP || lctx.model.arch == LLM_ARCH_GEMMA4_ASSISTANT) && lctx.mtp_target_ctx != nullptr)
|
|
||||||
? lctx.mtp_target_ctx->kv_self
|
|
||||||
: kv_self;
|
|
||||||
const int64_t n_kv = mask_kv_self.n;
|
|
||||||
const int64_t n_tokens = batch.n_tokens;
|
const int64_t n_tokens = batch.n_tokens;
|
||||||
|
|
||||||
|
|
||||||
@ -4582,6 +4583,48 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For causal models running in non-causal mode (e.g., Gemma-4 image decode),
|
||||||
|
// the flash-attn mask is allocated as [n_kv, n_tokens_pad] and must be filled
|
||||||
|
// using KV cache cell metadata — not batch-token indices — because image tokens
|
||||||
|
// occupy cells starting at n_past, not at cell 0.
|
||||||
|
if (cparams.flash_attn && hparams.causal_attn && !lctx.is_encoding) {
|
||||||
|
const ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY);
|
||||||
|
const ggml_half h_zero = ggml_fp32_to_fp16(0.f);
|
||||||
|
for (int j = 0; j < n_tokens; ++j) {
|
||||||
|
const llama_seq_id seq_id = batch.seq_id[j][0];
|
||||||
|
const llama_pos pos = batch.pos[j];
|
||||||
|
for (int i = 0; i < n_kv; ++i) {
|
||||||
|
const bool valid = mask_kv_self.cells[i].has_seq_id(seq_id);
|
||||||
|
if (data_f16) {
|
||||||
|
data_f16[j*n_kv + i] = valid ? h_zero : h_inf;
|
||||||
|
}
|
||||||
|
if (data_swa_f16) {
|
||||||
|
ggml_half h = valid ? h_zero : h_inf;
|
||||||
|
if (h == h_zero) {
|
||||||
|
if (hparams.n_attn_chunk) {
|
||||||
|
const llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
|
||||||
|
if (mask_kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
|
||||||
|
h = h_inf;
|
||||||
|
}
|
||||||
|
} else if (pos - mask_kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
|
||||||
|
h = h_inf;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
data_swa_f16[j*n_kv + i] = h;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const int64_t n_tokens_padded = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD);
|
||||||
|
if (n_tokens_padded > n_tokens) {
|
||||||
|
if (data_f16) {
|
||||||
|
std::fill(data_f16 + n_tokens*n_kv, data_f16 + n_tokens_padded*n_kv, h_inf);
|
||||||
|
}
|
||||||
|
if (data_swa_f16) {
|
||||||
|
std::fill(data_swa_f16 + n_tokens*n_kv, data_swa_f16 + n_tokens_padded*n_kv, h_inf);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
|
||||||
for (int h = 0; h < 1; ++h) {
|
for (int h = 0; h < 1; ++h) {
|
||||||
for (int j = 0; j < n_tokens; ++j) {
|
for (int j = 0; j < n_tokens; ++j) {
|
||||||
const llama_seq_id seq_id = batch.seq_id[j][0];
|
const llama_seq_id seq_id = batch.seq_id[j][0];
|
||||||
@ -4639,6 +4682,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // end else (non-flash or non-causal-model path)
|
||||||
#if IK_PRINT_TIMING == 2
|
#if IK_PRINT_TIMING == 2
|
||||||
auto tim2 = ggml_time_us();
|
auto tim2 = ggml_time_us();
|
||||||
printf("set_inputs(mask2): %d us\n", int(tim2-tim1));
|
printf("set_inputs(mask2): %d us\n", int(tim2-tim1));
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user