diff --git a/src/models/plamo2.cpp b/src/models/plamo2.cpp index b93cf48bc5..0b81513c36 100644 --- a/src/models/plamo2.cpp +++ b/src/models/plamo2.cpp @@ -11,6 +11,10 @@ void llama_model_plamo2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + // Load attention parameters + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full, false); + for (uint32_t i = 0; i < hparams.n_layer(); ++i) { hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; } @@ -273,7 +277,7 @@ ggml_tensor * llama_model_plamo2::graph::build_plamo2_mamba_layer(llm_graph_inpu GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); - GGML_ASSERT(d_inner % n_head == 0); + GGML_ASSERT(d_inner % n_heads == 0); GGML_ASSERT(n_group == 0); ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);