diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 8fa207ce..513acd55 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -85,14 +85,11 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA4, "gemma4" }, { LLM_ARCH_GEMMA4_MTP, "gemma4_mtp" }, { LLM_ARCH_DFLASH_DRAFT, "dflash-draft" }, - { LLM_ARCH_GEMMA4_ASSISTANT,"gemma4_assistant" }, + { LLM_ARCH_GEMMA4_ASSISTANT,"gemma4-assistant" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; llm_arch llm_arch_from_string(const std::string & name) { - //if (name == "gemma4_assistant") { - // return llm_arch_from_string("gemma4_mtp"); - //} for (const auto & kv : LLM_ARCH_NAMES) { // NOLINT if (kv.second == name) { return kv.first; diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index b041844c..589e281d 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -877,9 +877,9 @@ void llm_load_hparams( ml.get_key(LLM_KV_MTP_CENTROID_COUNT, hparams.mtp_num_centroids, false); ml.get_key(LLM_KV_MTP_CENTROID_TOP_K, hparams.mtp_centroid_top_k, false); } else { - ml.get_key("gemma4_assistant.n_embd_backbone", hparams.mtp_backbone_n_embd); - ml.get_key("gemma4_assistant.n_centroids", hparams.mtp_num_centroids, false); - ml.get_key("gemma4_assistant.centroid_top_k", hparams.mtp_centroid_top_k, false); + ml.get_key("gemma4-assistant.embedding_length_out", hparams.mtp_backbone_n_embd); + ml.get_key("gemma4-assistant.n_centroids", hparams.mtp_num_centroids, false); + ml.get_key("gemma4-assistant.centroid_top_k", hparams.mtp_centroid_top_k, false); } ml.get_key(LLM_KV_MTP_USE_ORDERED_EMBEDDINGS, hparams.mtp_use_ordered_embeddings, false); diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp index 55845fc6..f151d885 100644 --- a/src/llama-load-tensors.cpp +++ b/src/llama-load-tensors.cpp @@ -2214,13 +2214,14 @@ bool create_tensors_helper::create_gemma4_mtp_tensors(const LLM_TN & tn) { model.mtp_token_ordering = create_tensor(ctx_output, tn(LLM_TENSOR_MTP_TOKEN_ORDERING, "weight"), {n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); model.mtp_centroids = create_tensor(ctx_output, tn(LLM_TENSOR_MTP_CENTROIDS, "weight"), {n_embd, hparams.mtp_num_centroids}, llama_model_loader::TENSOR_NOT_REQUIRED); } else { - model.mtp_pre_proj = create_tensor(ctx_output, "mtp.pre_projection.weight", {2*n_backbone, n_embd}, 0); - model.mtp_post_proj = create_tensor(ctx_output, "mtp.post_projection.weight", {n_embd, n_backbone}, 0); + model.mtp_pre_proj = create_tensor(ctx_output, "nextn.pre_projection.weight", {2*n_backbone, n_embd}, 0); + model.mtp_post_proj = create_tensor(ctx_output, "nextn.post_projection.weight", {n_embd, n_backbone}, 0); model.mtp_token_ordering = create_tensor(ctx_output, "mtp.token_ordering.weight", {n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); printf("========================== hparams.mtp_num_centroids = %d\n", hparams.mtp_num_centroids); model.mtp_centroids = create_tensor(ctx_output, "mtp.centroids.weight", {n_embd, hparams.mtp_num_centroids}, llama_model_loader::TENSOR_NOT_REQUIRED); } + int rope_flag = 0; for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); @@ -2230,7 +2231,11 @@ bool create_tensors_helper::create_gemma4_mtp_tensors(const LLM_TN & tn) { const int64_t n_embd_head = hparams.n_embd_head_k(i); const int64_t n_ff_cur = hparams.n_ff(i); - layer.rope_freqs = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + if (!hparams.swa_layers[i]) { + layer.rope_freqs = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), { n_rot/2 }, + llama_model_loader::TENSOR_NOT_REQUIRED | rope_flag); + rope_flag = llama_model_loader::TENSOR_DUPLICATED; + } layer.attn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head*n_head}, 0);