models : fix plamo2 attention_key/value_length regression (#24317)

This commit is contained in:
Sigbjørn Skjæret 2026-06-09 09:26:44 +02:00 committed by GitHub
parent fd3271e0b4
commit f0152efe40
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -11,6 +11,10 @@ void llama_model_plamo2::load_arch_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
// Load attention parameters
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full, false);
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full, false);
for (uint32_t i = 0; i < hparams.n_layer(); ++i) {
hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0;
}
@ -273,7 +277,7 @@ ggml_tensor * llama_model_plamo2::graph::build_plamo2_mamba_layer(llm_graph_inpu
GGML_ASSERT(n_seqs != 0);
GGML_ASSERT(ubatch.equal_seqs());
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
GGML_ASSERT(d_inner % n_head == 0);
GGML_ASSERT(d_inner % n_heads == 0);
GGML_ASSERT(n_group == 0);
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);