graph : ensure DS32 kq_mask_lid is F32 (#23864)

This commit is contained in:
Sigbjørn Skjæret 2026-05-29 19:55:14 +02:00 committed by GitHub
parent b5f52280fb
commit 764f1e64a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 7 deletions

View File

@ -2656,14 +2656,18 @@ llm_graph_input_attn_k_dsa * llm_graph_context::build_attn_inp_k_dsa() const {
inp->self_k_idxs_mla = mctx_cur->get_mla()->build_input_k_idxs(ctx0, ubatch);
inp->self_kq_mask_mla = build_attn_inp_kq_mask(ctx0, mctx_cur->get_mla(), ubatch, cparams);
inp->self_kq_mask_mla_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_mla, GGML_TYPE_F16) : inp->self_kq_mask_mla;
inp->self_kq_mask_mla_cnv = inp->self_kq_mask_mla;
}
{
inp->self_k_idxs_lid = mctx_cur->get_lid()->build_input_k_idxs(ctx0, ubatch);
inp->self_kq_mask_lid = build_attn_inp_kq_mask(ctx0, mctx_cur->get_lid(), ubatch, cparams);
inp->self_kq_mask_lid_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_lid, GGML_TYPE_F16) : inp->self_kq_mask_lid;
// ensure F32 mask
auto cparams_copy = cparams;
cparams_copy.flash_attn = false;
inp->self_kq_mask_lid = build_attn_inp_kq_mask(ctx0, mctx_cur->get_lid(), ubatch, cparams_copy);
inp->self_kq_mask_lid_cnv = inp->self_kq_mask_lid;
inp->self_k_rot_lid = mctx_cur->get_lid()->build_input_k_rot(ctx0);
}

View File

@ -399,10 +399,10 @@ public:
ggml_tensor * self_k_idxs_mla = nullptr; // I64 [n_batch]
ggml_tensor * self_k_idxs_lid = nullptr; // I64 [n_batch]
ggml_tensor * self_kq_mask_mla = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_mla_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_lid = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_lid_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_mla = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_mla_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_lid = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_lid_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_k_rot_lid = nullptr;