mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
graph : ensure DS32 kq_mask_lid is F32 (#23864)
This commit is contained in:
parent
b5f52280fb
commit
764f1e64a1
@ -2656,14 +2656,18 @@ llm_graph_input_attn_k_dsa * llm_graph_context::build_attn_inp_k_dsa() const {
|
||||
inp->self_k_idxs_mla = mctx_cur->get_mla()->build_input_k_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask_mla = build_attn_inp_kq_mask(ctx0, mctx_cur->get_mla(), ubatch, cparams);
|
||||
inp->self_kq_mask_mla_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_mla, GGML_TYPE_F16) : inp->self_kq_mask_mla;
|
||||
inp->self_kq_mask_mla_cnv = inp->self_kq_mask_mla;
|
||||
}
|
||||
|
||||
{
|
||||
inp->self_k_idxs_lid = mctx_cur->get_lid()->build_input_k_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask_lid = build_attn_inp_kq_mask(ctx0, mctx_cur->get_lid(), ubatch, cparams);
|
||||
inp->self_kq_mask_lid_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_lid, GGML_TYPE_F16) : inp->self_kq_mask_lid;
|
||||
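// ensure F32 mask: build it from a copy of cparams with flash_attn disabled (presumably the mask would otherwise be F16 when flash_attn is enabled - confirm against build_attn_inp_kq_mask)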
|
||||
auto cparams_copy = cparams;
|
||||
cparams_copy.flash_attn = false;
|
||||
|
||||
inp->self_kq_mask_lid = build_attn_inp_kq_mask(ctx0, mctx_cur->get_lid(), ubatch, cparams_copy);
|
||||
inp->self_kq_mask_lid_cnv = inp->self_kq_mask_lid;
|
||||
|
||||
inp->self_k_rot_lid = mctx_cur->get_lid()->build_input_k_rot(ctx0);
|
||||
}
|
||||
|
||||
@ -399,10 +399,10 @@ public:
|
||||
ggml_tensor * self_k_idxs_mla = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_k_idxs_lid = nullptr; // I64 [n_batch]
|
||||
|
||||
ggml_tensor * self_kq_mask_mla = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_mla_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_lid = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_lid_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_mla = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_mla_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_lid = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_lid_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
|
||||
ggml_tensor * self_k_rot_lid = nullptr;
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user