diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 442335240f..eaf0494aab 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -767,8 +767,8 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_PRE_PROJ, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_POST_PROJ, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_PRE_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_POST_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the // last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so // the model loader doesn't fault on the block index. diff --git a/src/llama-arch.h b/src/llama-arch.h index 82d97dc200..1e71c19be4 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -557,8 +557,8 @@ enum llm_tensor { LLM_TENSOR_INDEXER_PROJ, LLM_TENSOR_INDEXER_ATTN_K, LLM_TENSOR_INDEXER_ATTN_Q_B, - LLM_TENSOR_NEXTN_PRE_PROJ, - LLM_TENSOR_NEXTN_POST_PROJ, + LLM_TENSOR_NEXTN_PRE_PROJ, // TODO: rename to PROJ_PRE + LLM_TENSOR_NEXTN_POST_PROJ, // TODO: rename to PROJ_POST LLM_TENSOR_NEXTN_EH_PROJ, LLM_TENSOR_NEXTN_EMBED_TOKENS, LLM_TENSOR_NEXTN_ENORM, diff --git a/src/models/gemma4-assistant.cpp b/src/models/gemma4-assistant.cpp index 10f69fa3d8..8c274e0cbd 100644 --- a/src/models/gemma4-assistant.cpp +++ b/src/models/gemma4-assistant.cpp @@ -39,7 +39,6 @@ void llama_model_gemma4_assistant::load_arch_tensors(llama_model_loader &) { output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); const int64_t n_embd_backbone = hparams.n_embd_out(); - nextn_pre_proj = create_tensor(tn(LLM_TENSOR_NEXTN_PRE_PROJ, "weight"), { 2*n_embd_backbone, n_embd }, 0); nextn_post_proj = create_tensor(tn(LLM_TENSOR_NEXTN_POST_PROJ, "weight"), { n_embd, n_embd_backbone }, 0); int rope_freqs_flag = 0; @@ -51,6 +50,10 @@ void llama_model_gemma4_assistant::load_arch_tensors(llama_model_loader &) { const int64_t n_embd_head = hparams.n_embd_head_k(i); const int64_t n_ff = hparams.n_ff(i); + if (i == 0) { + nextn_pre_proj = create_tensor(tn(LLM_TENSOR_NEXTN_PRE_PROJ, "weight", i), { 2*n_embd_backbone, n_embd }, 0); + } + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head*n_head }, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head*n_head, n_embd }, 0);