diff --git a/ggml/src/iqk/fa/iqk_fa_templates.h b/ggml/src/iqk/fa/iqk_fa_templates.h index b6292fcd..f1776804 100644 --- a/ggml/src/iqk/fa/iqk_fa_templates.h +++ b/ggml/src/iqk/fa/iqk_fa_templates.h @@ -1173,12 +1173,16 @@ struct FlashQKV { S += expf(s - fms.M[j]); } } - GGML_ASSERT(S > 0); - auto norm = F16::set1(1/S); - //auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f); - for (int i = 0; i < D/F16::block_size; ++i) { - auto r = F16::load(R + F16::block_size*i); - F16::store(qkv + F16::block_size*i, F16::mul(norm, r)); + if (S > 0) { + auto norm = F16::set1(1/S); + for (int i = 0; i < D/F16::block_size; ++i) { + auto r = F16::load(R + F16::block_size*i); + F16::store(qkv + F16::block_size*i, F16::mul(norm, r)); + } + } else { + for (int i = 0; i < D/F16::block_size; ++i) { + F16::store(qkv + F16::block_size*i, F16::zero()); + } } }