From c2b8bca80711e33ec51c117ab013b4c5a3509511 Mon Sep 17 00:00:00 2001 From: Samuel Oliveira Alves <107287165+SamuelOliveirads@users.noreply.github.com> Date: Sun, 10 May 2026 01:44:20 -0300 Subject: [PATCH] Add MTP Support for Gemma 4 (#1744) * gemma-mtp: build the arch to load the MTP model * gemma-mtp: fix mtp kv state * gemma-mtp: refactor some functions and create gguf * gemma-mtp: make usable for embeddings models variant * gemma-mtp: fix qwen mtp load in graph split * gemma-mtp: refactor tensor creation and adjust output tensor handling * Gemma 4 MTP: improve tensor handling, and adjust split mode logic --- common/common.cpp | 11 +- common/speculative.cpp | 62 ++++--- common/speculative.h | 3 +- convert_hf_to_gguf.py | 261 ++++++++++++++++++++++++++ examples/server/server-context.cpp | 200 +++++++++++++++----- examples/server/server-context.h | 1 + ggml/src/ggml-backend.cpp | 3 +- gguf-py/gguf/constants.py | 84 +++++++++ gguf-py/gguf/tensor_mapping.py | 45 +++++ include/llama.h | 5 + src/graphs/build_gemma4.cpp | 282 +++++++++++++++++++++++++++++ src/llama-arch.cpp | 5 + src/llama-arch.h | 9 + src/llama-build-context.cpp | 95 ++++++---- src/llama-build-context.h | 5 +- src/llama-context.h | 1 + src/llama-hparams.cpp | 20 ++ src/llama-hparams.h | 7 + src/llama-load-tensors.cpp | 70 ++++++- src/llama-model.cpp | 43 +++++ src/llama-model.h | 5 + src/llama.cpp | 130 +++++++++---- 22 files changed, 1193 insertions(+), 154 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 56ea32a4..bb8ed772 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1101,8 +1101,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD; } else if (value == "suffix") { params.speculative.type = COMMON_SPECULATIVE_TYPE_SUFFIX; + } else if (value == "mtp") { + params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP; + params.has_mtp = true; } else { - throw std::invalid_argument("unknown speculative decoding type without draft model"); + throw std::invalid_argument("unknown speculative decoding type"); } return true; } @@ -2760,7 +2763,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param " per-step save SSM state per draft step in VRAM; no re-decode on rejection\n" " gpu-fallback copy state to GPU buffer; re-decode on rejection\n" " cpu serialise state via llama_state_seq; re-decode on rejection" }); - options.push_back({ "*", "--spec-type Name [none | ngram - cache | ngram - simple | ngram - map - k | ngram - map - k4v | ngram - mod | suffix]", "type of speculative decoding to use when no draft model is provided (default: %d)\n", (int)params.speculative.type}); + options.push_back({ "*", "--spec-type Name [none | mtp | ngram - cache | ngram - simple | ngram - map - k | ngram - map - k4v | ngram - mod | suffix]", "type of speculative decoding to use (default: %d)\n", (int)params.speculative.type}); options.push_back({ "*", "--spec-ngram-size-n N", "ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)\n",params.speculative.ngram_size_n }); options.push_back({ "*", "--spec-ngram-size-m N", "ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: %d)\n", params.speculative.ngram_size_m }); @@ -3355,11 +3358,9 @@ std::string fs_get_cache_file(const std::string & filename) { } -// -// Model utils -// struct llama_init_result llama_init_from_gpt_params(gpt_params & params) { llama_init_result iparams; + auto mparams = common_model_params_to_llama(params); llama_model * model = nullptr; diff --git a/common/speculative.cpp b/common/speculative.cpp index 5abb7467..531ce010 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -19,6 +19,9 @@ #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 +void llama_set_mtp_target_context(struct llama_context * ctx, struct llama_context * target_ctx); +uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx); + const std::vector common_speculative_types = { COMMON_SPECULATIVE_TYPE_NONE, COMMON_SPECULATIVE_TYPE_DRAFT, @@ -154,27 +157,28 @@ struct common_speculative_state_mtp : public common_speculative_state { llama_context * ctx_tgt; llama_context * ctx_mtp = nullptr; common_sampler * smpl; + // For Gemma 4 external MTP assistant: draft positions are held constant + bool constant_draft_positions = false; common_speculative_state_mtp( enum common_speculative_type type, llama_context * ctx_tgt, - const llama_context_params & mtp_cparams) + llama_context * ctx_mtp, + bool constant_draft_positions = false) : common_speculative_state(type) , ctx_tgt(ctx_tgt) + , ctx_mtp(ctx_mtp) + , constant_draft_positions(constant_draft_positions) { - struct common_params_sampling params; - params.samplers_sequence = { + struct common_params_sampling sparams; + sparams.samplers_sequence = { llama_sampler_type::DIST, }; - smpl = common_sampler_init(llama_get_model(ctx_tgt), params); + smpl = common_sampler_init(llama_get_model(ctx_mtp), sparams); + llama_set_mtp_target_context(ctx_mtp, ctx_tgt); - const llama_model * model = llama_get_model(ctx_tgt); - ctx_mtp = llama_init_from_model(const_cast(model), mtp_cparams); - if (ctx_mtp) { - LOG_INF("%s: created MTP context (n_ctx=%d)\n", __func__, llama_n_ctx(ctx_mtp)); - } else { - LOG_ERR("%s: failed to create MTP context\n", __func__); - } + LOG_INF("%s: MTP context ready (n_ctx=%d, constant_draft_positions=%s)\n", __func__, + llama_n_ctx(ctx_mtp), constant_draft_positions ? "true" : "false"); } ~common_speculative_state_mtp() override { @@ -211,7 +215,8 @@ struct common_speculative_state_mtp : public common_speculative_state { params.p_min, id_last, n_past, - seq_id + seq_id, + constant_draft_positions ); } @@ -1029,9 +1034,9 @@ common_speculative * common_speculative_init( // Compute the implementations to use based on the config and their order of preference std::vector configs = {}; // list of speculative configs to try { - bool has_draft = !params.mparams_dft.path.empty(); bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3 bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP); + bool has_draft = !params.mparams_dft.path.empty() && !has_mtp; bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE); bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE); @@ -1102,15 +1107,20 @@ common_speculative * common_speculative_init( break; } case COMMON_SPECULATIVE_TYPE_MTP: { - auto mtp_state = std::make_unique(config.type, - /* .ctx_tgt = */ ctx_tgt, - /* .mtp_cparams = */ params.cparams_dft - ); - if (!mtp_state->ctx_mtp) { - LOG_ERR("%s: failed to create MTP context\n", __func__); - return nullptr; + llama_context * ctx_mtp = ctx_dft; + if (!ctx_mtp) { + const llama_model * model = llama_get_model(ctx_tgt); + ctx_mtp = llama_init_from_model(const_cast(model), params.cparams_dft); + if (!ctx_mtp) { + LOG_ERR("%s: failed to create MTP context\n", __func__); + return nullptr; + } } - impls.push_back(std::move(mtp_state)); + ctx_dft = nullptr; + + const bool use_constant_draft_positions = llama_model_is_gemma4_mtp_assistant(llama_get_model(ctx_mtp)); + impls.push_back(std::make_unique( + config.type, ctx_tgt, ctx_mtp, use_constant_draft_positions)); break; } case COMMON_SPECULATIVE_TYPE_EAGLE3: { @@ -1224,7 +1234,7 @@ static mtp_last_embd & mtp_get_last_embd(const llama_context * ctx) { static std::unordered_map map; auto & last = map[ctx]; if (last.embd.empty()) { - auto n_embd = llama_model_n_embd(llama_get_model(ctx)); + auto n_embd = llama_mtp_state_n_embd(ctx); last.embd.resize(n_embd); } return last; @@ -1377,7 +1387,8 @@ std::vector mtp_speculative_gen_draft( float p_min, llama_token id_last, int32_t n_past, - llama_seq_id seq_id) { + llama_seq_id seq_id, + bool constant_draft_positions) { llama_tokens drafts; drafts.reserve(n_draft); @@ -1394,7 +1405,7 @@ std::vector mtp_speculative_gen_draft( llama_token current_input_id = id_last; int32_t current_n_past = n_past; - const int n_embd = llama_model_n_embd(llama_get_model(ctx)); + const int n_embd = llama_mtp_state_n_embd(ctx); auto & last = mtp_get_last_embd(ctx); int i0 = 0; @@ -1415,7 +1426,8 @@ std::vector mtp_speculative_gen_draft( int n_decode = 0; for (int i = i0; i < n_draft; ++i) { mtp_batch.n_tokens = 0; - common_batch_add(mtp_batch, current_input_id, current_n_past, {seq_id}, true); + const int32_t draft_pos = constant_draft_positions ? n_past : current_n_past; + common_batch_add(mtp_batch, current_input_id, draft_pos, {seq_id}, true); ++n_decode; if (llama_decode(ctx, mtp_batch) != 0) { diff --git a/common/speculative.h b/common/speculative.h index 2b3fc0c0..6061d133 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -60,7 +60,8 @@ std::vector mtp_speculative_gen_draft( float p_min, llama_token id_last, int32_t n_past, - llama_seq_id seq_id); + llama_seq_id seq_id, + bool constant_draft_positions = false); void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup); diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0a3341fa..0664e5aa 100644 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3175,6 +3175,267 @@ class Gemma2Model(Model): return [(self.map_tensor_name(name), data_torch)] +class Gemma4BaseModel(Model): + model_arch = gguf.MODEL_ARCH.GEMMA4 + + def _text_hparams(self) -> dict[str, Any]: + text_hparams = self.hparams.get("text_config") + if isinstance(text_hparams, dict): + return text_hparams + return self.hparams + + def _arch_name(self) -> str: + return gguf.MODEL_ARCH_NAMES[self.model_arch] + + def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any: + text_hparams = self.hparams.get("text_config") + if isinstance(text_hparams, dict): + for key in keys: + if key in text_hparams: + return text_hparams[key] + return super().find_hparam(keys, optional) + + def set_vocab(self): + vocab = gguf.LlamaHfVocab(self.dir_model) + tokens = [] + scores = [] + toktypes = [] + visible_tokens = { + "<|channel>", + "", + "<|tool_call>", + "", + "<|tool_response>", + "", + "<|\"|>", + } + + for text, score, toktype in vocab.all_tokens(): + tokens.append(text) + scores.append(score) + text_str = text.decode() + if text_str in visible_tokens: + toktypes.append(gguf.TokenType.USER_DEFINED) + logger.info(f"Token {text_str!r} is set to USER_DEFINED") + else: + toktypes.append(toktype) + + assert len(tokens) == vocab.vocab_size + + self.gguf_writer.add_tokenizer_model("gemma4") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + special_vocab.add_to_gguf(self.gguf_writer) + self.gguf_writer.add_add_space_prefix(False) + self.gguf_writer.add_add_bos_token(True) + + +@Model.register("Gemma4ForConditionalGeneration") +class Gemma4Model(Gemma4BaseModel): + model_arch = gguf.MODEL_ARCH.GEMMA4 + + def set_gguf_parameters(self): + hparams = self._text_hparams() + block_count = hparams["num_hidden_layers"] + arch = self._arch_name() + + self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) + self.gguf_writer.add_embedding_length(hparams["hidden_size"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(hparams["num_attention_heads"]) + self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"]) + self.gguf_writer.add_file_type(self.ftype) + + swa_layers = [layer_type == "sliding_attention" for layer_type in hparams["layer_types"]] + self.gguf_writer.add_sliding_window(hparams["sliding_window"]) + self.gguf_writer.add_sliding_window_pattern(swa_layers) + + num_kv_shared_layers = hparams.get("num_kv_shared_layers", 0) + self.gguf_writer.add_shared_kv_layers(num_kv_shared_layers) + + n_ff = hparams["intermediate_size"] + if hparams.get("use_double_wide_mlp", False): + first_kv_shared_layer_idx = block_count - num_kv_shared_layers + n_ff_arr = [n_ff if il < first_kv_shared_layer_idx else n_ff * 2 for il in range(block_count)] + self.gguf_writer.add_feed_forward_length(n_ff_arr) + else: + self.gguf_writer.add_feed_forward_length(n_ff) + + expert_intermediate_size = hparams.get("expert_intermediate_size") or hparams.get("moe_intermediate_size") + if expert_intermediate_size is not None: + self.gguf_writer.add_expert_feed_forward_length(expert_intermediate_size) + + n_pl_embd = hparams.get("hidden_size_per_layer_input") or 0 + self.gguf_writer.add_embedding_length_per_layer_input(n_pl_embd) + + head_dim_full = int(hparams["global_head_dim"]) + head_dim_swa = int(hparams["head_dim"]) + self.gguf_writer.add_key_length(head_dim_full) + self.gguf_writer.add_value_length(head_dim_full) + self.gguf_writer.add_uint32(f"{arch}.attention.key_length_swa", head_dim_swa) + self.gguf_writer.add_uint32(f"{arch}.attention.value_length_swa", head_dim_swa) + + num_kv_full = hparams.get("num_global_key_value_heads") + num_kv_swa = hparams.get("num_key_value_heads") + if num_kv_full is not None and num_kv_swa is not None: + kv_heads = [num_kv_swa if is_swa else num_kv_full for is_swa in swa_layers] + self.gguf_writer.add_head_count_kv(kv_heads) + elif num_kv_swa is not None: + self.gguf_writer.add_head_count_kv(num_kv_swa) + + rope_parameters = hparams.get("rope_parameters", {}) + rope_full = rope_parameters.get("full_attention", {}) + rope_swa = rope_parameters.get("sliding_attention", {}) + self.gguf_writer.add_rope_dimension_count(head_dim_full) + partial_rotary_factor_swa = float(rope_swa.get("partial_rotary_factor", hparams.get("partial_rotary_factor", 1.0))) + self.gguf_writer.add_uint32(f"{arch}.rope.dimension_count_swa", int(head_dim_swa * partial_rotary_factor_swa)) + self.gguf_writer.add_rope_freq_base(float(rope_full.get("rope_theta", 1000000.0))) + self.gguf_writer.add_float32(f"{arch}.rope.freq_base_swa", float(rope_swa.get("rope_theta", 10000.0))) + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + hparams = self._text_hparams() + rope_params_full = hparams["rope_parameters"]["full_attention"] + assert rope_params_full["rope_type"] == "proportional" + + head_dim_full = int(hparams["global_head_dim"]) + partial_rotary_factor_full = rope_params_full["partial_rotary_factor"] + n_rot_full = int(head_dim_full * partial_rotary_factor_full / 2) + n_unrot_full = int(head_dim_full / 2) - n_rot_full + values = [1.0] * n_rot_full + [1e30] * n_unrot_full + rope_freqs_full = torch.tensor(values, dtype=torch.float32) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), rope_freqs_full) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.endswith("per_dim_scale") or name.endswith("layer_scalar"): + name = name + ".weight" + + if "language_model." not in name and "rope_freqs" not in name: + return [] + + name = name.replace("language_model.", "") + + if name == "lm_head.weight": + logger.debug(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.") + return [] + + if name.endswith("router.scale"): + return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_INP, bid, ".scale"), data_torch)] + + if ".per_expert_scale" in name: + return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN_EXP, bid, ".scale"), data_torch)] + + if ".experts." in name and not name.endswith(".weight"): + name += ".weight" + + return [(self.map_tensor_name(name), data_torch)] + + +@Model.register("Gemma4AssistantForCausalLM") +class Gemma4AssistantModel(Gemma4BaseModel): + model_arch = gguf.MODEL_ARCH.GEMMA4_MTP + + _root_tensor_map = { + "model.embed_tokens.weight": "token_embd.weight", + "model.norm.weight": "output_norm.weight", + "pre_projection.weight": "mtp_pre_proj.weight", + "post_projection.weight": "mtp_post_proj.weight", + "masked_embedding.centroids.weight": "mtp_centroids.weight", + "masked_embedding.token_ordering": "mtp_token_ordering.weight", + "token_ordering": "mtp_token_ordering.weight", + "token_ordering.weight": "mtp_token_ordering.weight", + "model.token_ordering": "mtp_token_ordering.weight", + "model.token_ordering.weight": "mtp_token_ordering.weight", + "centroids": "mtp_centroids.weight", + "centroids.weight": "mtp_centroids.weight", + "model.centroids": "mtp_centroids.weight", + "model.centroids.weight": "mtp_centroids.weight", + } + + _layer_tensor_map = { + "input_layernorm.weight": "attn_norm.weight", + "self_attn.q_proj.weight": "attn_q.weight", + "self_attn.q_norm.weight": "attn_q_norm.weight", + "self_attn.o_proj.weight": "attn_output.weight", + "post_attention_layernorm.weight": "post_attention_norm.weight", + "pre_feedforward_layernorm.weight": "ffn_norm.weight", + "mlp.gate_proj.weight": "ffn_gate.weight", + "mlp.up_proj.weight": "ffn_up.weight", + "mlp.down_proj.weight": "ffn_down.weight", + "post_feedforward_layernorm.weight": "post_ffw_norm.weight", + "layer_scalar": "layer_output_scale.weight", + "layer_scalar.weight": "layer_output_scale.weight", + } + + def set_gguf_parameters(self): + hparams = self._text_hparams() + arch = self._arch_name() + sliding_pattern = [layer_type == "sliding_attention" for layer_type in hparams["layer_types"]] + + head_dim_swa = int(hparams["head_dim"]) + head_dim_full = int(hparams.get("global_head_dim") or head_dim_swa) + n_kv_swa = int(hparams["num_key_value_heads"]) + n_kv_full = int(hparams.get("num_global_key_value_heads") or n_kv_swa) + n_kv = [n_kv_swa if is_sliding else n_kv_full for is_sliding in sliding_pattern] + + self.gguf_writer.add_context_length(int(hparams["max_position_embeddings"])) + self.gguf_writer.add_embedding_length(int(hparams["hidden_size"])) + self.gguf_writer.add_block_count(int(hparams["num_hidden_layers"])) + self.gguf_writer.add_feed_forward_length(int(hparams["intermediate_size"])) + self.gguf_writer.add_head_count(int(hparams["num_attention_heads"])) + self.gguf_writer.add_head_count_kv(n_kv) + self.gguf_writer.add_key_length(head_dim_full) + self.gguf_writer.add_value_length(head_dim_full) + self.gguf_writer.add_uint32(f"{arch}.attention.key_length_swa", head_dim_swa) + self.gguf_writer.add_uint32(f"{arch}.attention.value_length_swa", head_dim_swa) + self.gguf_writer.add_layer_norm_rms_eps(float(hparams["rms_norm_eps"])) + self.gguf_writer.add_sliding_window(int(hparams["sliding_window"])) + self.gguf_writer.add_array(f"{arch}.attention.sliding_window_pattern", sliding_pattern) + self.gguf_writer.add_rope_dimension_count(head_dim_full) + self.gguf_writer.add_uint32(f"{arch}.rope.dimension_count_swa", head_dim_swa) + + rope_parameters = hparams.get("rope_parameters", {}) + rope_full = rope_parameters.get("full_attention", {}) + rope_swa = rope_parameters.get("sliding_attention", {}) + self.gguf_writer.add_rope_freq_base(float(rope_full.get("rope_theta", 1000000.0))) + self.gguf_writer.add_float32(f"{arch}.rope.freq_base_swa", float(rope_swa.get("rope_theta", 10000.0))) + + self.gguf_writer.add_uint32(f"{arch}.backbone_embedding_length", int(self.hparams["backbone_hidden_size"])) + self.gguf_writer.add_bool(f"{arch}.use_ordered_embeddings", bool(self.hparams.get("use_ordered_embeddings", False))) + self.gguf_writer.add_uint32(f"{arch}.centroid_count", int(self.hparams.get("num_centroids", 0))) + self.gguf_writer.add_uint32(f"{arch}.centroid_top_k", int(self.hparams.get("centroid_intermediate_top_k", 0))) + self.gguf_writer.add_file_type(self.ftype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + mapped_name = self._root_tensor_map.get(name) + if mapped_name is not None: + if mapped_name == "mtp_token_ordering.weight": + n_vocab = int(data_torch.shape[0]) + n_centroids = int(self.hparams.get("num_centroids", 2048)) + tokens_per_centroid = n_vocab // n_centroids + inv_ordering = torch.zeros(n_vocab, dtype=torch.int32) + tok_ord_i32 = data_torch.to(dtype=torch.int64) + inv_ordering[tok_ord_i32] = torch.arange(n_vocab, dtype=torch.int32) + token_to_centroid = (inv_ordering // tokens_per_centroid).to(dtype=torch.int32) + return [(mapped_name, token_to_centroid)] + return [(mapped_name, data_torch)] + + prefix = "model.layers." + if not name.startswith(prefix): + raise ValueError(f"Unsupported Gemma 4 assistant tensor: {name}") + + layer_id, suffix = name[len(prefix):].split(".", 1) + mapped_suffix = self._layer_tensor_map.get(suffix) + if mapped_suffix is None: + raise ValueError(f"Unsupported Gemma 4 assistant tensor: {name}") + + return [(f"blk.{layer_id}.{mapped_suffix}", data_torch)] + + @Model.register("Starcoder2ForCausalLM") class StarCoder2Model(Model): model_arch = gguf.MODEL_ARCH.STARCODER2 diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index 12b44a8d..dd0f0f4d 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -16,6 +16,8 @@ #include #include +uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx); + static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1, int32_t offset = 0) { if (pos_min == -1) { pos_min = llama_kv_cache_seq_pos_min(ctx, id); @@ -44,6 +46,97 @@ static void log_text(const gpt_params & params_base, const std::string & text) { } } +static bool params_use_gemma4_external_mtp(const gpt_params & params_base) { + return params_base.has_mtp && + llama_model_is_gemma4_mtp_assistant(params_base.speculative.model_dft); +} + +static llama_context * get_slot_mtp_ctx(server_slot & slot, llama_context * ctx) { + llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec); + return mtp_ctx ? mtp_ctx : ctx; +} + +static int get_ctx_mtp_n_embd(llama_context * ctx) { + return ctx ? (int) llama_mtp_state_n_embd(ctx) : 0; +} + +static int get_slot_mtp_n_embd(server_slot & slot, llama_context * ctx) { + return get_ctx_mtp_n_embd(get_slot_mtp_ctx(slot, ctx)); +} + +static void cache_slot_mtp_hidden(server_slot & slot, const float * hidden, int n_embd) { + if (hidden == nullptr || n_embd <= 0) { + return; + } + + slot.mtp_hidden_state.assign(hidden, hidden + n_embd); +} + +static void sync_slot_mtp_hidden(server_slot & slot, llama_context * ctx) { + if (!slot.has_mtp || !slot.spec || slot.mtp_hidden_state.empty()) { + return; + } + + const int n_embd = get_slot_mtp_n_embd(slot, ctx); + if (n_embd <= 0 || slot.mtp_hidden_state.size() < (size_t) n_embd) { + return; + } + + const int n_hidden = slot.mtp_hidden_state.size() / n_embd; + llama_set_draft_input_hidden_state(get_slot_mtp_ctx(slot, ctx), slot.mtp_hidden_state.data() + (n_hidden - 1) * n_embd); +} + +static void cache_and_sync_slot_mtp_hidden(server_slot & slot, llama_context * ctx, const float * hidden, int n_embd) { + cache_slot_mtp_hidden(slot, hidden, n_embd); + sync_slot_mtp_hidden(slot, ctx); +} + +static void cache_and_sync_slot_mtp_hidden_from_rows(server_slot & slot, llama_context * ctx, const std::vector & rows, int n_embd) { + if (rows.empty() || n_embd <= 0) { + return; + } + + const size_t n_rows = rows.size() / n_embd; + if (n_rows == 0) { + return; + } + + cache_and_sync_slot_mtp_hidden(slot, ctx, rows.data() + (n_rows - 1) * n_embd, n_embd); +} + +static void apply_slot_mtp_accept( + server_slot & slot, + llama_context * ctx, + const std::vector & mtp_hidden_state, + const std::vector & ids, + int32_t mtp_n_past_base, + int n_embd) { + if (!slot.has_mtp || mtp_hidden_state.empty() || n_embd <= 0) { + return; + } + + if (slot.use_gemma4_external_mtp) { + cache_and_sync_slot_mtp_hidden_from_rows(slot, ctx, mtp_hidden_state, n_embd); + return; + } + + slot.mtp_hidden_state = mtp_hidden_state; + llama_set_draft_input_hidden_state(get_slot_mtp_ctx(slot, ctx), slot.mtp_hidden_state.data()); + mtp_accept_tokens(get_slot_mtp_ctx(slot, ctx), ids, mtp_n_past_base, slot.id); +} + +static void set_external_mtp_hidden(server_slot & slot, llama_context * ctx, const float * hidden, int n_embd) { + if (!slot.has_mtp || !slot.spec || hidden == nullptr || n_embd <= 0) { + return; + } + + cache_and_sync_slot_mtp_hidden(slot, ctx, hidden, n_embd); +} + +static void set_external_mtp_hidden_from_rows(server_slot & slot, llama_context * ctx, const std::vector & rows, int n_embd) { + cache_and_sync_slot_mtp_hidden_from_rows(slot, ctx, rows, n_embd); +} + void server_speculative_checkpoint::clear() { valid = false; per_step_enabled = false; @@ -185,6 +278,7 @@ bool server_context::load_model(const gpt_params& params_) { gpt_params params_dft; params_dft.devices = params_base.speculative.devices; params_dft.model = params_base.speculative.model; + params_dft.main_gpu = params_base.main_gpu; params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; params_dft.rpc_servers = params_base.rpc_servers; params_dft.cache_type_k = params_base.speculative.cache_type_k.empty() ? params_base.cache_type_k : params_base.speculative.cache_type_k; @@ -279,16 +373,22 @@ void server_context::init() { slot.sparams = params_base.sparams; if (params_base.has_mtp) { - if (llama_model_n_nextn_layer(model) > 0) { + const bool has_external_mtp = params_use_gemma4_external_mtp(params_base); + + if (llama_model_n_nextn_layer(model) > 0 || has_external_mtp) { params_base.speculative.type = COMMON_SPECULATIVE_TYPE_MTP; params_base.pooling_type = LLAMA_POOLING_TYPE_NONE; - params_base.speculative.cparams_dft = common_context_params_to_llama(params_base); + if (!has_external_mtp) { + params_base.speculative.cparams_dft = common_context_params_to_llama(params_base); + } + params_base.speculative.cparams_dft.mtp = true; params_base.speculative.cparams_dft.mtp_op_type = MTP_OP_WARMUP; params_base.speculative.cparams_dft.embeddings = true; slot.has_mtp = true; + slot.use_gemma4_external_mtp = has_external_mtp; slot.params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP; slot.params.speculative.n_min = 0; slot.params.speculative.cparams_dft = params_base.speculative.cparams_dft; @@ -3276,20 +3376,14 @@ void server_context::add_sampled_tokens() { auto & params_spec = slot.params.speculative; if (slot.has_mtp) { - llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec); - llama_context * hs_ctx = mtp_ctx ? mtp_ctx : ctx; if (!slot.mtp_hidden_state.empty()) { - const int n_embd = llama_model_n_embd(llama_get_model(ctx)); - const int n_hidden = slot.mtp_hidden_state.size() / n_embd; - llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data() + (n_hidden - 1) * n_embd); + sync_slot_mtp_hidden(slot, ctx); } else { LOG_ERROR("MTP hidden state is empty during speculation", {}); const float* emb_neg1 = llama_get_embeddings_ith(ctx, -1); if (emb_neg1) { - const int n_embd = llama_model_n_embd(llama_get_model(ctx)); - slot.mtp_hidden_state.resize(n_embd); - memcpy(slot.mtp_hidden_state.data(), emb_neg1, n_embd * sizeof(float)); - llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data()); + const int n_embd = get_ctx_mtp_n_embd(ctx); + cache_and_sync_slot_mtp_hidden(slot, ctx, emb_neg1, n_embd); } } } @@ -3857,11 +3951,8 @@ static void restore_speculative_checkpoint( // Update MTP KV cache and hidden state using embeddings collected before checkpoint restore. if (slot.has_mtp && !mtp_hidden_state_pre.empty()) { - slot.mtp_hidden_state = mtp_hidden_state_pre; - llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec); - llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx; - llama_set_draft_input_hidden_state(mtp_target, slot.mtp_hidden_state.data()); - mtp_accept_tokens(mtp_target, ids, mtp_n_past_base, slot.id); + const int n_embd = get_ctx_mtp_n_embd(ctx); + apply_slot_mtp_accept(slot, ctx, mtp_hidden_state_pre, ids, mtp_n_past_base, n_embd); } SLT_DBG(slot, "per-step restore: step=%d (rejected %d drafts)\n", @@ -3895,7 +3986,7 @@ static void restore_speculative_checkpoint( SLT_ERR(slot, "failed to re-decode accepted tokens after checkpoint restore: %d\n", ret); } if (slot.has_mtp) { - const int n_embd = llama_model_n_embd(llama_get_model(ctx)); + const int n_embd = get_ctx_mtp_n_embd(ctx); const int n_accepted = (int)ids.size(); slot.mtp_hidden_state.resize(n_accepted * n_embd); @@ -3906,15 +3997,17 @@ static void restore_speculative_checkpoint( } } - llama_context * mtp_ctx_rej = common_speculative_get_mtp_ctx(slot.spec); - llama_context * mtp_target_rej = mtp_ctx_rej ? mtp_ctx_rej : ctx; - llama_set_draft_input_hidden_state(mtp_target_rej, slot.mtp_hidden_state.data()); - mtp_accept_tokens(mtp_target_rej, ids, slot.spec_ckpt.n_past, slot.id); + if (slot.use_gemma4_external_mtp) { + cache_and_sync_slot_mtp_hidden_from_rows(slot, ctx, slot.mtp_hidden_state, n_embd); + } else { + llama_set_draft_input_hidden_state(get_slot_mtp_ctx(slot, ctx), slot.mtp_hidden_state.data()); + mtp_accept_tokens(get_slot_mtp_ctx(slot, ctx), ids, slot.spec_ckpt.n_past, slot.id); - if (n_accepted > 1) { - memmove(slot.mtp_hidden_state.data(), - slot.mtp_hidden_state.data() + (n_accepted - 1) * n_embd, - n_embd * sizeof(float)); + if (n_accepted > 1) { + memmove(slot.mtp_hidden_state.data(), + slot.mtp_hidden_state.data() + (n_accepted - 1) * n_embd, + n_embd * sizeof(float)); + } } slot.mtp_hidden_state.resize(n_embd); } @@ -3972,7 +4065,7 @@ void server_context::speculative_decoding_accept() { if (slot.has_mtp) { mtp_n_past_base = slot.n_past - (slot.drafted.size() + 1); - const int n_embd = llama_model_n_embd(llama_get_model(ctx)); + const int n_embd = get_ctx_mtp_n_embd(ctx); if (!ids.empty()) { mtp_hidden_state_pre.resize(ids.size() * n_embd); for (size_t i = 0; i < ids.size(); i++) { @@ -4018,13 +4111,9 @@ void server_context::speculative_decoding_accept() { restore_speculative_checkpoint(slot, ctx, model, ids, n_draft, mtp_hidden_state_pre, mtp_n_past_base); } else { if (slot.has_mtp && !mtp_hidden_state_pre.empty()) { - llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec); - llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx; - - slot.mtp_hidden_state = std::move(mtp_hidden_state_pre); - llama_set_draft_input_hidden_state(mtp_target, slot.mtp_hidden_state.data()); - mtp_accept_tokens(mtp_target, ids, mtp_n_past_base, slot.id); - } + const int n_embd = get_ctx_mtp_n_embd(ctx); + apply_slot_mtp_accept(slot, ctx, mtp_hidden_state_pre, ids, mtp_n_past_base, n_embd); + } llama_kv_cache_seq_rm(ctx, slot.id, slot.cache_tokens.pos_next(slot.n_past), -1); discard_speculative_checkpoint(slot, ctx); } @@ -4396,8 +4485,18 @@ void server_context::process_batch_tokens(int32_t & n_batch) { } bool mtp_warmup_needed = false; + llama_context * batch_mtp_target = nullptr; std::vector batch_mtp_hidden_state; if (params_base.has_mtp) { + for (auto & slot : slots) { + if (slot.spec && slot.has_mtp) { + llama_context * mc = common_speculative_get_mtp_ctx(slot.spec); + if (mc) { + batch_mtp_target = mc; + break; + } + } + } for (auto& slot : slots) { if ((slot.state == SLOT_STATE_PROCESSING && slot.n_decoded == 0) || (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT)) { @@ -4409,13 +4508,16 @@ void server_context::process_batch_tokens(int32_t & n_batch) { } } if (mtp_warmup_needed) { - const int n_embd = llama_model_n_embd(llama_get_model(ctx)); + llama_context * mtp_target = batch_mtp_target ? batch_mtp_target : ctx; + const int n_embd_src = get_ctx_mtp_n_embd(ctx); + const int n_embd_dst = get_ctx_mtp_n_embd(mtp_target); const int n_toks = batch_view.n_tokens; - batch_mtp_hidden_state.resize(n_toks * n_embd); + batch_mtp_hidden_state.assign(n_toks * n_embd_dst, 0.0f); for (int t = 0; t < n_toks; t++) { const float* emb_t = llama_get_embeddings_ith(ctx, t); if (emb_t) { - memcpy(batch_mtp_hidden_state.data() + t * n_embd, emb_t, n_embd * sizeof(float)); + const int n_copy = std::min(n_embd_src, n_embd_dst); + memcpy(batch_mtp_hidden_state.data() + t * n_embd_dst, emb_t, n_copy * sizeof(float)); } } } @@ -4469,9 +4571,12 @@ void server_context::process_batch_tokens(int32_t & n_batch) { if (params_base.has_mtp && slot.n_decoded == 0) { const float* emb_i = llama_get_embeddings_ith(ctx, tok_idx); if (emb_i) { - const int n_embd = llama_model_n_embd(llama_get_model(ctx)); - slot.mtp_hidden_state.resize(n_embd); - memcpy(slot.mtp_hidden_state.data(), emb_i, n_embd * sizeof(float)); + const int n_embd = get_ctx_mtp_n_embd(ctx); + if (slot.use_gemma4_external_mtp) { + set_external_mtp_hidden(slot, ctx, emb_i, n_embd); + } else { + cache_slot_mtp_hidden(slot, emb_i, n_embd); + } } } @@ -4537,16 +4642,17 @@ void server_context::process_batch_tokens(int32_t & n_batch) { slot.i_batch = -1; } if (mtp_warmup_needed && !batch_mtp_hidden_state.empty()) { - llama_context * mtp_ctx = nullptr; - for (auto & slot : slots) { - if (slot.spec && slot.has_mtp) { - llama_context * mc = common_speculative_get_mtp_ctx(slot.spec); - if (mc) { mtp_ctx = mc; break; } + if (params_use_gemma4_external_mtp(params_base)) { + for (auto & slot : slots) { + if (slot.spec && slot.has_mtp && !slot.mtp_hidden_state.empty()) { + sync_slot_mtp_hidden(slot, ctx); + } } + } else { + llama_context * mtp_target = batch_mtp_target ? batch_mtp_target : ctx; + llama_set_draft_input_hidden_state(mtp_target, batch_mtp_hidden_state.data()); + mtp_update_kv_cache(mtp_target, batch_view, true); } - llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx; - llama_set_draft_input_hidden_state(mtp_target, batch_mtp_hidden_state.data()); - mtp_update_kv_cache(mtp_target, batch_view, true); } // speculative decoding - main model sample and accept diff --git a/examples/server/server-context.h b/examples/server/server-context.h index e2ca56b8..f1009ae2 100644 --- a/examples/server/server-context.h +++ b/examples/server/server-context.h @@ -171,6 +171,7 @@ struct server_slot { decltype(ctx_sampling->elb_states) elb_prev_states; bool has_mtp = false; + bool use_gemma4_external_mtp = false; std::vector mtp_hidden_state; // saves recurrent state before a speculative batch so it can be restored on rejection diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 97c7481a..805ca315 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1107,7 +1107,8 @@ static bool ggml_is_view_op(enum ggml_op op) { #endif #ifndef GGML_SCHED_MAX_SPLIT_INPUTS -#define GGML_SCHED_MAX_SPLIT_INPUTS GGML_MAX_SRC +// Gemma4 with per-layer embeddings and uses up to 32 inputs +#define GGML_SCHED_MAX_SPLIT_INPUTS 32 #endif #ifndef GGML_SCHED_MAX_COPIES diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index f4e95d0e..9219f0e5 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -234,6 +234,8 @@ class MODEL_ARCH(IntEnum): GEMMA = auto() GEMMA2 = auto() GEMMA3 = auto() + GEMMA4 = auto() + GEMMA4_MTP = auto() STARCODER2 = auto() MAMBA = auto() XVERSE = auto() @@ -282,7 +284,10 @@ class MODEL_TENSOR(IntEnum): FFN_GATE_INP_SHEXP = auto() FFN_NORM = auto() FFN_PRE_NORM = auto() + FFN_PRE_NORM_2 = auto() FFN_POST_NORM = auto() + FFN_POST_NORM_1 = auto() + FFN_POST_NORM_2 = auto() FFN_GATE = auto() FFN_DOWN = auto() FFN_UP = auto() @@ -291,6 +296,7 @@ class MODEL_TENSOR(IntEnum): FFN_GATE_EXP = auto() FFN_DOWN_EXP = auto() FFN_UP_EXP = auto() + FFN_GATE_UP_EXP = auto() FFN_GATE_SHEXP = auto() FFN_DOWN_SHEXP = auto() FFN_UP_SHEXP = auto() @@ -298,6 +304,13 @@ class MODEL_TENSOR(IntEnum): ATTN_Q_NORM = auto() ATTN_K_NORM = auto() LAYER_OUT_NORM = auto() + LAYER_OUT_SCALE = auto() + PER_LAYER_TOKEN_EMBD = auto() + PER_LAYER_MODEL_PROJ = auto() + PER_LAYER_INP_GATE = auto() + PER_LAYER_PROJ = auto() + PER_LAYER_PROJ_NORM = auto() + PER_LAYER_POST_NORM = auto() SSM_IN = auto() SSM_CONV1D = auto() SSM_X = auto() @@ -349,6 +362,10 @@ class MODEL_TENSOR(IntEnum): NEXTN_HNORM = auto() # nextn tensors (glm4moe) NEXTN_SHARED_HEAD_HEAD = auto() # nextn tensors (glm4moe) NEXTN_SHARED_HEAD_NORM = auto() # nextn tensors (glm4moe) + MTP_PRE_PROJ = auto() + MTP_POST_PROJ = auto() + MTP_TOKEN_ORDERING = auto() + MTP_CENTROIDS = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -383,6 +400,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.GEMMA: "gemma", MODEL_ARCH.GEMMA2: "gemma2", MODEL_ARCH.GEMMA3: "gemma3", + MODEL_ARCH.GEMMA4: "gemma4", + MODEL_ARCH.GEMMA4_MTP: "gemma4_mtp", MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.MAMBA: "mamba", MODEL_ARCH.XVERSE: "xverse", @@ -434,7 +453,10 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp", MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", MODEL_TENSOR.FFN_PRE_NORM: "blk.{bid}.ffn_norm", + MODEL_TENSOR.FFN_PRE_NORM_2: "blk.{bid}.pre_ffw_norm_2", MODEL_TENSOR.FFN_POST_NORM: "blk.{bid}.post_ffw_norm", + MODEL_TENSOR.FFN_POST_NORM_1: "blk.{bid}.post_ffw_norm_1", + MODEL_TENSOR.FFN_POST_NORM_2: "blk.{bid}.post_ffw_norm_2", MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", @@ -446,8 +468,16 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps", MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps", MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", + MODEL_TENSOR.FFN_GATE_UP_EXP: "blk.{bid}.ffn_gate_up_exps", MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b", MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", + MODEL_TENSOR.LAYER_OUT_SCALE: "blk.{bid}.layer_output_scale", + MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd", + MODEL_TENSOR.PER_LAYER_MODEL_PROJ: "per_layer_model_proj", + MODEL_TENSOR.PER_LAYER_INP_GATE: "blk.{bid}.inp_gate", + MODEL_TENSOR.PER_LAYER_PROJ: "blk.{bid}.proj", + MODEL_TENSOR.PER_LAYER_PROJ_NORM: "per_layer_proj_norm", + MODEL_TENSOR.PER_LAYER_POST_NORM: "blk.{bid}.post_norm", MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in", MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x", @@ -500,6 +530,10 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.NEXTN_HNORM: "blk.{bid}.nextn.hnorm", MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: "blk.{bid}.nextn.shared_head_head", MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: "blk.{bid}.nextn.shared_head_norm", + MODEL_TENSOR.MTP_PRE_PROJ: "mtp_pre_proj", + MODEL_TENSOR.MTP_POST_PROJ: "mtp_post_proj", + MODEL_TENSOR.MTP_TOKEN_ORDERING: "mtp_token_ordering", + MODEL_TENSOR.MTP_CENTROIDS: "mtp_centroids", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -962,6 +996,56 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_PRE_NORM, MODEL_TENSOR.FFN_POST_NORM, ], + MODEL_ARCH.GEMMA4: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_PRE_NORM_2, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_POST_NORM, + MODEL_TENSOR.FFN_POST_NORM_1, + MODEL_TENSOR.FFN_POST_NORM_2, + MODEL_TENSOR.FFN_GATE_UP_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.LAYER_OUT_SCALE, + MODEL_TENSOR.PER_LAYER_TOKEN_EMBD, + MODEL_TENSOR.PER_LAYER_MODEL_PROJ, + MODEL_TENSOR.PER_LAYER_INP_GATE, + MODEL_TENSOR.PER_LAYER_PROJ, + MODEL_TENSOR.PER_LAYER_PROJ_NORM, + MODEL_TENSOR.PER_LAYER_POST_NORM, + ], + MODEL_ARCH.GEMMA4_MTP: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_POST_NORM, + MODEL_TENSOR.LAYER_OUT_SCALE, + MODEL_TENSOR.MTP_PRE_PROJ, + MODEL_TENSOR.MTP_POST_PROJ, + MODEL_TENSOR.MTP_TOKEN_ORDERING, + MODEL_TENSOR.MTP_CENTROIDS, + ], MODEL_ARCH.STARCODER2: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 367f47be..63c45f4b 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -245,12 +245,24 @@ class TensorNameMap: "model.layers.{bid}.pre_feedforward_layernorm", # gemma2 ), + MODEL_TENSOR.FFN_PRE_NORM_2: ( + "model.layers.{bid}.pre_feedforward_layernorm_2", # gemma4 + ), + # Post feed-forward norm MODEL_TENSOR.FFN_POST_NORM: ( "model.layers.{bid}.post_feedforward_layernorm", # gemma2 "model.layers.{bid}.post_moe_norm", # grok-2 ), + MODEL_TENSOR.FFN_POST_NORM_1: ( + "model.layers.{bid}.post_feedforward_layernorm_1", # gemma4 + ), + + MODEL_TENSOR.FFN_POST_NORM_2: ( + "model.layers.{bid}.post_feedforward_layernorm_2", # gemma4 + ), + MODEL_TENSOR.FFN_GATE_INP: ( "layers.{bid}.feed_forward.gate", # mixtral "model.layers.{bid}.block_sparse_moe.gate", # mixtral @@ -305,6 +317,11 @@ class TensorNameMap: "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe (merged) ), + MODEL_TENSOR.FFN_GATE_UP_EXP: ( + "model.layers.{bid}.mlp.experts.gate_up_proj", # gemma4 + "model.layers.{bid}.experts.gate_up_proj", # gemma4 + ), + MODEL_TENSOR.FFN_UP_SHEXP: ( "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek2 @@ -413,6 +430,34 @@ class TensorNameMap: "model.layers.{bid}.final_layernorm", # bailingmoe2 ), + MODEL_TENSOR.LAYER_OUT_SCALE: ( + "model.layers.{bid}.layer_scalar", # gemma4 + ), + + MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: ( + "model.embed_tokens_per_layer", # gemma4 + ), + + MODEL_TENSOR.PER_LAYER_MODEL_PROJ: ( + "model.per_layer_model_projection", # gemma4 + ), + + MODEL_TENSOR.PER_LAYER_PROJ_NORM: ( + "model.per_layer_projection_norm", # gemma4 + ), + + MODEL_TENSOR.PER_LAYER_INP_GATE: ( + "model.layers.{bid}.per_layer_input_gate", # gemma4 + ), + + MODEL_TENSOR.PER_LAYER_PROJ: ( + "model.layers.{bid}.per_layer_projection", # gemma4 + ), + + MODEL_TENSOR.PER_LAYER_POST_NORM: ( + "model.layers.{bid}.post_per_layer_input_norm", # gemma4 + ), + MODEL_TENSOR.SSM_IN: ( "model.layers.{bid}.in_proj", "backbone.layers.{bid}.mixer.in_proj", diff --git a/include/llama.h b/include/llama.h index ac0a275b..42539c70 100644 --- a/include/llama.h +++ b/include/llama.h @@ -685,6 +685,11 @@ extern "C" { LLAMA_API bool llama_model_has_recurrent(const struct llama_model * model); + // Returns true if the model is a Gemma 4 MTP assistant (external frozen-KV speculative drafter) + LLAMA_API bool llama_model_is_gemma4_mtp_assistant(const struct llama_model * model); + + LLAMA_API bool llama_is_gemma4_mtp_file(const char * path); + LLAMA_API bool llama_model_is_split_mode_graph(const struct llama_model * model); // Returns 0 on success diff --git a/src/graphs/build_gemma4.cpp b/src/graphs/build_gemma4.cpp index 4a9a39ae..fa254bb7 100644 --- a/src/graphs/build_gemma4.cpp +++ b/src/graphs/build_gemma4.cpp @@ -2,6 +2,128 @@ #include "../llama-model.h" #include "../llama-context.h" +static int gemma4_mtp_target_kv_layer(const llama_hparams & mtp_hparams, const llama_hparams & target_hparams, int mtp_il) { + GGML_ASSERT(mtp_il >= 0 && mtp_il < (int) mtp_hparams.n_layer); + + const bool is_sliding = mtp_hparams.swa_layers[mtp_il] != 0; + const int target_n_kv_layer = target_hparams.n_layer_kv_from_start > 0 + ? std::min((int) target_hparams.n_layer, target_hparams.n_layer_kv_from_start) + : (int) target_hparams.n_layer; + + int target_il = target_n_kv_layer - 1; + for (; target_il >= 0; --target_il) { + if ((target_hparams.swa_layers[target_il] != 0) == is_sliding) { + break; + } + } + + GGML_ASSERT(target_il >= 0 && "Gemma4 MTP could not find a matching target KV layer"); + return target_il; +} + +static void gemma4_mtp_prepare_frozen_kv_views( + ggml_context * ctx0, + llama_context & lctx, + const llama_kv_cache & target_kv, + int assistant_il, + int target_il, + int32_t target_n_kv, + ggml_tensor ** frozen_k, + ggml_tensor ** frozen_v, + const llm_build_cb & cb) { + if (*frozen_k || *frozen_v) { + GGML_ASSERT(*frozen_k && *frozen_v); + return; + } + + if (!lctx.cparams.flash_attn) { + return; + } + + GGML_ASSERT(target_il >= 0 && target_il < (int) target_kv.k_l.size() && target_il < (int) target_kv.v_l.size()); + + ggml_tensor * k_cache = target_kv.k_l[target_il]; + ggml_tensor * v_cache = target_kv.v_l[target_il]; + if (!k_cache || !v_cache || !k_cache->extra || !v_cache->extra) { + return; + } + + auto * split_k = (ggml_split_tensor_t *) k_cache->extra; + auto * split_v = (ggml_split_tensor_t *) v_cache->extra; + + GGML_ASSERT(split_k && split_v); + GGML_ASSERT(split_k->n_device == split_v->n_device); + + const llama_hparams & assistant_hparams = lctx.model.hparams; + const int64_t n_embd_head_k = assistant_hparams.n_embd_head_k(assistant_il); + const int64_t n_embd_head_v = assistant_hparams.n_embd_head_v(assistant_il); + + std::vector k_parts; + std::vector v_parts; + k_parts.reserve(split_k->n_device); + v_parts.reserve(split_v->n_device); + + for (int id = 0; id < split_k->n_device; ++id) { + ggml_tensor * split_kl = split_k->splits[id]; + ggml_tensor * split_vl = split_v->splits[id]; + + GGML_ASSERT((split_kl && split_vl) || (!split_kl && !split_vl)); + if (!split_kl) { + continue; + } + + GGML_ASSERT(target_kv.size > 0); + GGML_ASSERT(split_kl->ne[1] % target_kv.size == 0); + + const int64_t split_n_head_kv = split_kl->ne[1] / target_kv.size; + + ggml_tensor * k_part = ggml_view_3d(ctx0, split_kl, + n_embd_head_k, target_n_kv, split_n_head_kv, + ggml_row_size(split_kl->type, n_embd_head_k) * split_n_head_kv, + ggml_row_size(split_kl->type, n_embd_head_k), + 0); + if (k_part->type != GGML_TYPE_F32) { + k_part = ggml_cast(ctx0, k_part, GGML_TYPE_F32); + } + cb(k_part, "mtp_frozen_k_split", 1000 * (assistant_il + 1) + id); + + ggml_tensor * v_part = ggml_view_3d(ctx0, split_vl, + n_embd_head_v, target_n_kv, split_n_head_kv, + ggml_row_size(split_vl->type, split_n_head_kv * n_embd_head_v), + ggml_row_size(split_vl->type, n_embd_head_v), + 0); + if (v_part->type != GGML_TYPE_F32) { + v_part = ggml_cast(ctx0, v_part, GGML_TYPE_F32); + } + cb(v_part, "mtp_frozen_v_split", 1000 * (assistant_il + 1) + id); + + k_parts.push_back(k_part); + v_parts.push_back(v_part); + } + + GGML_ASSERT(!k_parts.empty() && k_parts.size() == v_parts.size()); + + ggml_tensor * k_full = k_parts[0]; + ggml_tensor * v_full = v_parts[0]; + for (size_t i = 1; i < k_parts.size(); ++i) { + k_full = ggml_concat(ctx0, k_full, k_parts[i], 2); + v_full = ggml_concat(ctx0, v_full, v_parts[i], 2); + } + + if (k_full->type != GGML_TYPE_F16) { + k_full = ggml_cast(ctx0, k_full, GGML_TYPE_F16); + } + if (v_full->type != GGML_TYPE_F16) { + v_full = ggml_cast(ctx0, v_full, GGML_TYPE_F16); + } + + cb(k_full, "mtp_frozen_k", assistant_il); + cb(v_full, "mtp_frozen_v", assistant_il); + + *frozen_k = k_full; + *frozen_v = v_full; +} + static ggml_cgraph * build_gemma4_graph_parallel(llm_build_context & llm, llama_context & lctx, ggml_context * ctx0, ggml_tensor * inpL, ggml_tensor * inp_pos, ggml_tensor * inp_out_ids, ggml_tensor * KQ_mask, ggml_tensor * KQ_mask_swa, int n_tokens, const llm_build_cb & cb) { @@ -363,6 +485,160 @@ static ggml_cgraph * build_gemma4_graph_parallel(llm_build_context & llm, llama_ return gf; } +ggml_cgraph * llm_build_context::build_gemma4_mtp() { + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(n_tokens), false); + + const int64_t n_embd = hparams.n_embd; + const int64_t n_vocab = hparams.n_vocab; + const int64_t n_backbone = hparams.mtp_backbone_n_embd; + const int32_t n_layer = hparams.n_layer; + const bool has_target_ctx = lctx.mtp_target_ctx != nullptr; + + GGML_ASSERT(n_backbone > 0); + + lctx.inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, batch.n_tokens); + cb(lctx.inp_tokens, "inp_tokens", -1); + ggml_set_input(lctx.inp_tokens); + + ggml_tensor * hidden_state = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_backbone, n_tokens); + ggml_set_name(hidden_state, "inp_mtp_states"); + ggml_set_input(hidden_state); + lctx.inp_mtp_states = hidden_state; + + if (!has_target_ctx || !batch.token) { + ggml_tensor * cur = ggml_view_2d(ctx0, hidden_state, n_embd, n_tokens, + ggml_row_size(hidden_state->type, n_backbone), 0); + cb(cur, "mtp_init_hidden_view", -1); + + ggml_tensor * mtp_embd = ggml_dup(ctx0, hidden_state); + cb(mtp_embd, "result_mtp_embd", -1); + ggml_build_forward_expand(gf, mtp_embd); + + ggml_tensor * logits = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb); + cb(logits, "result_output", -1); + ggml_build_forward_expand(gf, logits); + + GGML_UNUSED(n_vocab); + return gf; + } + + const llama_model & target_model = lctx.mtp_target_ctx->model; + const llama_hparams & target_hparams = target_model.hparams; + const llama_cparams & target_cparams = lctx.mtp_target_ctx->cparams; + const llama_kv_cache & target_kv = lctx.mtp_target_ctx->kv_self; + + GGML_ASSERT(n_tokens <= target_kv.n); + + ggml_tensor * inp_pos = build_inp_pos(); + + ggml_tensor * token_embd = ggml_get_rows(ctx0, target_model.tok_embd, lctx.inp_tokens); + cb(token_embd, "inp_embd_target", -1); + token_embd = ggml_scale(ctx0, token_embd, std::sqrt(float(n_backbone))); + cb(token_embd, "inp_embd_scaled", -1); + + ggml_tensor * cur = ggml_concat(ctx0, token_embd, hidden_state, 0); + cb(cur, "inp_mtp_combined", -1); + cur = llm_build_lora_mm(lctx, ctx0, model.mtp_pre_proj, cur); + cb(cur, "mtp_pre_proj", -1); + + const int32_t target_n_kv = target_kv.n; + const int32_t target_kv_head = target_kv.head; + + ggml_tensor * KQ_mask = nullptr; + ggml_tensor * KQ_mask_swa = nullptr; + ggml_tensor * frozen_k_swa = nullptr; + ggml_tensor * frozen_v_swa = nullptr; + ggml_tensor * frozen_k_full = nullptr; + ggml_tensor * frozen_v_full = nullptr; + { + const int64_t n_mask_tokens = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); + lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32, target_n_kv, n_mask_tokens); + cb(lctx.inp_KQ_mask, "KQ_mask", -1); + ggml_set_input(lctx.inp_KQ_mask); + KQ_mask = lctx.inp_KQ_mask; + + if (target_hparams.n_swa > 0) { + lctx.inp_KQ_mask_swa = ggml_new_tensor_2d(ctx0, flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32, target_n_kv, n_mask_tokens); + cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1); + ggml_set_input(lctx.inp_KQ_mask_swa); + KQ_mask_swa = lctx.inp_KQ_mask_swa; + } + } + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpL = cur; + + const bool is_sliding = hparams.swa_layers[il] ? true : false; + const float freq_base_l = is_sliding ? target_hparams.rope_freq_base_train_swa : target_cparams.rope_freq_base; + const float freq_scale_l = is_sliding ? target_hparams.rope_freq_scale_train_swa : target_cparams.rope_freq_scale; + const int n_rot_l = is_sliding ? target_hparams.n_rot_swa : target_hparams.n_rot; + const int n_swa = is_sliding ? target_hparams.n_swa : 0; + const int n_embd_head = hparams.n_embd_head_k(il); + const int n_head = hparams.n_head(il); + ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask; + + cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il); + cb(Qcur, "Qcur_normed", il); + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "Qcur_rope", il); + + const int target_il = gemma4_mtp_target_kv_layer(hparams, target_hparams, il); + ggml_tensor *& frozen_k = is_sliding ? frozen_k_swa : frozen_k_full; + ggml_tensor *& frozen_v = is_sliding ? frozen_v_swa : frozen_v_full; + gemma4_mtp_prepare_frozen_kv_views(ctx0, lctx, target_kv, il, target_il, target_n_kv, &frozen_k, &frozen_v, cb); + cur = llm_build_kv(ctx0, lctx, target_kv, gf, model.layers[il].wo, model.layers[il].bo, + nullptr, nullptr, Qcur, KQ_mask_l, n_tokens, target_kv_head, target_n_kv, hparams.f_attention_scale, cb, il, nullptr, n_swa, target_il, + &frozen_k, &frozen_v); + + cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, cb, il); + cb(cur, "attn_post_norm", il); + cur = ggml_add(ctx0, cur, inpL); + cb(cur, "attn_out", il); + + ggml_tensor * ffn = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, cur, + model.layers[il].ffn_up, nullptr, nullptr, + model.layers[il].ffn_gate, nullptr, nullptr, + model.layers[il].ffn_down, nullptr, nullptr, + nullptr, + LLM_FFN_GELU, LLM_FFN_PAR, cb, il, gf, true, false, nullptr, model.layers[il].ffn_post_norm); + cb(ffn, "ffn_out", il); + + cur = ffn; + if (model.layers[il].out_scale) { + cur = ggml_mul(ctx0, cur, model.layers[il].out_scale); + cb(cur, "out_scaled", il); + } + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + } + + ggml_tensor * mtp_embd = llm_build_lora_mm(lctx, ctx0, model.mtp_post_proj, cur); + cb(mtp_embd, "result_mtp_embd", -1); + ggml_build_forward_expand(gf, mtp_embd); + + ggml_tensor * logits; + // E2B/E4B: The centroid/token-ordering tensors are kept in the GGUF for future use but + // not required for correct inference — the full-vocab matmul against the tied output + // weight still yields valid per-token logits. + { + logits = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb); + cb(logits, "result_output", -1); + } + ggml_build_forward_expand(gf, logits); + + GGML_UNUSED(n_embd); + GGML_UNUSED(n_vocab); + + return gf; +} + static ggml_tensor * gemma4_project_per_layer_inputs(ggml_context * ctx0, const llama_model & model, const llm_build_cb & cb, int n_embd, int n_embd_per_layer, int n_layer, int n_tokens, ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) { @@ -614,6 +890,12 @@ ggml_cgraph * llm_build_context::build_gemma4() { cur = inpL; + if (cparams.mtp) { + ggml_tensor * mtp_embd = ggml_dup(ctx0, cur); + cb(mtp_embd, "result_mtp_embd", -1); + ggml_build_forward_expand(gf, mtp_embd); + } + cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM_RMS, cb, -1); cb(cur, "result_norm", -1); diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 74568260..7e2bb4c4 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -78,6 +78,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GLM_DSA, "glm-dsa" }, { LLM_ARCH_MISTRAL4, "mistral4" }, { LLM_ARCH_GEMMA4, "gemma4" }, + { LLM_ARCH_GEMMA4_MTP, "gemma4_mtp" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -140,6 +141,10 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SWIGLU_LIMITS_SHARED, "%s.swiglu_limits_shared" }, { LLM_KV_SWIGLU_CLAMP_EXP, "%s.swiglu_clamp_exp" }, { LLM_KV_SWIGLU_CLAMP_SHEXP, "%s.swiglu_clamp_shexp" }, + { LLM_KV_MTP_BACKBONE_EMBEDDING_LENGTH, "%s.backbone_embedding_length" }, + { LLM_KV_MTP_USE_ORDERED_EMBEDDINGS, "%s.use_ordered_embeddings" }, + { LLM_KV_MTP_CENTROID_COUNT, "%s.centroid_count" }, + { LLM_KV_MTP_CENTROID_TOP_K, "%s.centroid_top_k" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index 1c5a6d99..5a148ad7 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -77,6 +77,7 @@ enum llm_arch { LLM_ARCH_GLM_DSA, LLM_ARCH_MISTRAL4, LLM_ARCH_GEMMA4, + LLM_ARCH_GEMMA4_MTP, LLM_ARCH_UNKNOWN, }; @@ -133,6 +134,10 @@ enum llm_kv { LLM_KV_SWIGLU_CLAMP_EXP, LLM_KV_SWIGLU_CLAMP_SHEXP, LLM_KV_EMBEDDING_LENGTH_PER_LAYER, + LLM_KV_MTP_BACKBONE_EMBEDDING_LENGTH, + LLM_KV_MTP_USE_ORDERED_EMBEDDINGS, + LLM_KV_MTP_CENTROID_COUNT, + LLM_KV_MTP_CENTROID_TOP_K, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -358,6 +363,10 @@ enum llm_tensor { LLM_TENSOR_FFN_PRE_NORM_2, // 105 LLM_TENSOR_FFN_POST_NORM_1, LLM_TENSOR_FFN_POST_NORM_2, + LLM_TENSOR_MTP_PRE_PROJ, + LLM_TENSOR_MTP_POST_PROJ, + LLM_TENSOR_MTP_TOKEN_ORDERING, + LLM_TENSOR_MTP_CENTROIDS, LLM_TENSOR_UNKNOWN, }; diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 9a3c245f..0ab5b689 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -1537,37 +1537,46 @@ static ggml_tensor * llm_build_kqv( float kq_scale, const llm_build_cb & cb, int il, - ggml_tensor * sinks = nullptr, int n_swa = 0) { + ggml_tensor * sinks = nullptr, int n_swa = 0, int kv_il = -1, + ggml_tensor ** k_cache_view = nullptr, ggml_tensor ** v_cache_view = nullptr) { const llama_model & model = lctx.model; const llama_hparams & hparams = lctx.model.hparams; const llama_cparams & cparams = lctx.cparams; - const int64_t n_ctx = cparams.n_ctx; + const int64_t n_ctx = kv.size; const int64_t n_head = hparams.n_head(il); const int64_t n_head_kv = hparams.n_head_kv(il); const int64_t n_embd_head_k = hparams.n_embd_head_k(il); //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); const int64_t n_embd_head_v = hparams.n_embd_head_v(il); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + const int kv_layer = kv_il >= 0 ? kv_il : il; struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3); cb(q, "q", il); - auto k_cache = lctx.model.hparams.has_kv(il) ? kv.k_l[il] + auto k_cache = kv_il >= 0 ? kv.k_l[kv_layer] + : lctx.model.hparams.has_kv(il) ? kv.k_l[il] : lctx.model.hparams.swa_layers[il] ? kv.k_l[hparams.n_layer_kv_from_start-2] : kv.k_l[hparams.n_layer_kv_from_start-1]; - auto v_cache = lctx.model.hparams.has_kv(il) ? kv.v_l[il] + auto v_cache = kv_il >= 0 ? kv.v_l[kv_layer] + : lctx.model.hparams.has_kv(il) ? kv.v_l[il] : lctx.model.hparams.swa_layers[il] ? kv.v_l[hparams.n_layer_kv_from_start-2] : kv.v_l[hparams.n_layer_kv_from_start-1]; GGML_ASSERT(k_cache != nullptr && "k_cache is null in llm_build_kqv"); GGML_ASSERT(v_cache != nullptr && "v_cache is null in llm_build_kqv"); - struct ggml_tensor * k = - ggml_view_3d(ctx, k_cache, - n_embd_head_k, n_kv, n_head_kv, - ggml_row_size(k_cache->type, n_embd_head_k)*n_head_kv, //n_embd_k_gqa), - ggml_row_size(k_cache->type, n_embd_head_k), - 0); - cb(k, "k", il); + struct ggml_tensor * k = k_cache_view ? *k_cache_view : nullptr; + if (!k) { + k = ggml_view_3d(ctx, k_cache, + n_embd_head_k, n_kv, n_head_kv, + ggml_row_size(k_cache->type, n_embd_head_k)*n_head_kv, //n_embd_k_gqa), + ggml_row_size(k_cache->type, n_embd_head_k), + 0); + if (k_cache_view) { + *k_cache_view = k; + } + cb(k, "k", il); + } #ifdef GGML_USE_VULKAN constexpr bool use_f32_precision = true; @@ -1594,13 +1603,18 @@ static ggml_tensor * llm_build_kqv( GGML_UNUSED(n_ctx); // split cached v into n_head heads (not transposed) - struct ggml_tensor * v = - ggml_view_3d(ctx, v_cache, - n_embd_head_v, n_kv, n_head_kv, - ggml_row_size(v_cache->type, n_embd_v_gqa), - ggml_row_size(v_cache->type, n_embd_head_v), - 0); - cb(v, "v", il); + struct ggml_tensor * v = v_cache_view ? *v_cache_view : nullptr; + if (!v) { + v = ggml_view_3d(ctx, v_cache, + n_embd_head_v, n_kv, n_head_kv, + ggml_row_size(v_cache->type, n_embd_v_gqa), + ggml_row_size(v_cache->type, n_embd_head_v), + 0); + if (v_cache_view) { + *v_cache_view = v; + } + cb(v, "v", il); + } cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias, hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f); @@ -1626,22 +1640,27 @@ static ggml_tensor * llm_build_kqv( } else { // split cached v into n_head heads - struct ggml_tensor * v; - if (kv.v_trans) { - v = ggml_view_3d(ctx, v_cache, - n_kv, n_embd_head_v, n_head_kv, - ggml_element_size(v_cache)*n_ctx, - ggml_element_size(v_cache)*n_ctx*n_embd_head_v, - 0); - } else { - v = ggml_view_3d(ctx, v_cache, - n_embd_head_v, n_kv, n_head_kv, - ggml_row_size(v_cache->type, n_embd_v_gqa), - ggml_row_size(v_cache->type, n_embd_head_v), - 0); - v = ggml_cont(ctx, ggml_transpose(ctx, v)); + struct ggml_tensor * v = v_cache_view ? *v_cache_view : nullptr; + if (!v) { + if (kv.v_trans) { + v = ggml_view_3d(ctx, v_cache, + n_kv, n_embd_head_v, n_head_kv, + ggml_element_size(v_cache)*n_ctx, + ggml_element_size(v_cache)*n_ctx*n_embd_head_v, + 0); + } else { + v = ggml_view_3d(ctx, v_cache, + n_embd_head_v, n_kv, n_head_kv, + ggml_row_size(v_cache->type, n_embd_v_gqa), + ggml_row_size(v_cache->type, n_embd_head_v), + 0); + v = ggml_cont(ctx, ggml_transpose(ctx, v)); + } + if (v_cache_view) { + *v_cache_view = v; + } + cb(v, "v", il); } - cb(v, "v", il); auto kq_size = k->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); if (cparams.attn_max_batch == 0 || cparams.attn_max_batch >= kq_size || k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2] || sinks) { @@ -1775,7 +1794,8 @@ ggml_tensor * llm_build_context::llm_build_kv( int32_t kv_head, int32_t n_kv, float kq_scale, - const llm_build_cb & cb, int il, ggml_tensor * sinks, int n_swa) { + const llm_build_cb & cb, int il, ggml_tensor * sinks, int n_swa, int kv_il, + ggml_tensor ** k_cache_view, ggml_tensor ** v_cache_view) { const llama_hparams & hparams = lctx.model.hparams; const llama_cparams & cparams = lctx.cparams; @@ -1805,7 +1825,8 @@ ggml_tensor * llm_build_context::llm_build_kv( llm_build_kv_store(lctx, ctx, hparams, cparams, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il); } - auto cur = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b, q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks, n_swa); + auto cur = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b, q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks, n_swa, kv_il, + k_cache_view, v_cache_view); cb(cur, "kqv_out", il); return cur; @@ -2332,6 +2353,10 @@ ggml_cgraph * llm_build_context::llama_build_graph( { result = llm.build_gemma4(); } break; + case LLM_ARCH_GEMMA4_MTP: + { + result = llm.build_gemma4_mtp(); + } break; case LLM_ARCH_STARCODER2: { result = llm.build_starcoder2(); diff --git a/src/llama-build-context.h b/src/llama-build-context.h index 375a203a..361f9710 100644 --- a/src/llama-build-context.h +++ b/src/llama-build-context.h @@ -240,6 +240,8 @@ struct llm_build_context { ggml_cgraph * build_gemma4(); + ggml_cgraph * build_gemma4_mtp(); + ggml_cgraph * build_starcoder2(); ggml_cgraph * build_mamba(); @@ -339,7 +341,8 @@ struct llm_build_context { int32_t kv_head, int32_t n_kv, float kq_scale, - const llm_build_cb & cb, int il, ggml_tensor * sinks = nullptr, int n_swa = 0); + const llm_build_cb & cb, int il, ggml_tensor * sinks = nullptr, int n_swa = 0, int kv_il = -1, + ggml_tensor ** k_cache_view = nullptr, ggml_tensor ** v_cache_view = nullptr); static ggml_tensor * llm_build_ffn(ggml_context * ctx, llama_context & lctx, ggml_tensor * ffn_norm, ggml_tensor * cur, diff --git a/src/llama-context.h b/src/llama-context.h index 075ec991..6554f562 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -199,6 +199,7 @@ struct llama_context { struct llama_cparams cparams; struct llama_sampling sampling; struct llama_kv_cache kv_self; + struct llama_context * mtp_target_ctx = nullptr; struct llama_control_vector cvec; std::vector scale_data; diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 530e105e..081aceb9 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -754,6 +754,26 @@ void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_GEMMA4_MTP: + { + ml.get_key(LLM_KV_MTP_BACKBONE_EMBEDDING_LENGTH, hparams.mtp_backbone_n_embd); + ml.get_key(LLM_KV_MTP_USE_ORDERED_EMBEDDINGS, hparams.mtp_use_ordered_embeddings, false); + ml.get_key(LLM_KV_MTP_CENTROID_COUNT, hparams.mtp_num_centroids, false); + ml.get_key(LLM_KV_MTP_CENTROID_TOP_K, hparams.mtp_centroid_top_k, false); + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + + hparams.n_layer_kv_from_start = hparams.n_layer; + hparams.f_attention_scale = 1.0f; + + switch (hparams.mtp_backbone_n_embd) { + case 5376: model.type = e_model::MODEL_32B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_STARCODER2: { diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 125a6b66..77ea6f13 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -134,6 +134,12 @@ struct llama_hparams { // gemma4 per-layer embedding uint32_t n_embd_per_layer = 0; + // gemma4 separate assistant MTP + uint32_t mtp_backbone_n_embd = 0; + bool mtp_use_ordered_embeddings = false; + uint32_t mtp_num_centroids = 0; + uint32_t mtp_centroid_top_k = 0; + // needed by encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 llama_token dec_start_token_id = -1; @@ -152,6 +158,7 @@ struct llama_hparams { if (this->n_vocab != other.n_vocab) return true; if (this->n_ctx_train != other.n_ctx_train) return true; if (this->n_embd != other.n_embd) return true; + if (this->mtp_backbone_n_embd != other.mtp_backbone_n_embd) return true; if (this->n_layer != other.n_layer) return true; if (this->n_rot != other.n_rot) return true; if (this->n_swa != other.n_swa) return true; diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp index af847e4e..4f4edae4 100644 --- a/src/llama-load-tensors.cpp +++ b/src/llama-load-tensors.cpp @@ -95,6 +95,8 @@ struct create_tensors_helper : public create_tensors_helper_interface { bool create_gemma4_tensors(const LLM_TN & tn); + bool create_gemma4_mtp_tensors(const LLM_TN & tn); + bool create_starcoder2_tensors(const LLM_TN & tn); bool create_mamba_tensors(const LLM_TN & tn); @@ -2016,6 +2018,7 @@ bool create_tensors_helper::create_gemma4_tensors(const LLM_TN & tn) { const uint32_t n_embd_per_layer = hparams.n_embd_per_layer; const int64_t n_ff_exp = hparams.n_ff_exp; + const bool use_split_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN; if (n_embd_head_k != n_embd_head_v) { throw std::runtime_error("Gemma 4 requires n_embd_head_k == n_embd_head_v"); @@ -2043,7 +2046,8 @@ bool create_tensors_helper::create_gemma4_tensors(const LLM_TN & tn) { int rope_freqs_flag = 0; for (int i = 0; i < n_layer; ++i) { - ggml_context * ctx_split = ctx_for_layer_split(i); + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = use_split_ctx ? ctx_for_layer_split(i) : ctx_layer; auto & layer = model.layers[i]; const int64_t n_head = hparams.n_head(i); const int64_t n_embd_head = hparams.n_embd_head_k(i); @@ -2110,6 +2114,53 @@ bool create_tensors_helper::create_gemma4_tensors(const LLM_TN & tn) { return use_mmap_buffer; } +bool create_tensors_helper::create_gemma4_mtp_tensors(const LLM_TN & tn) { + LOADING_PRELUDE + + const int64_t n_backbone = hparams.mtp_backbone_n_embd; + const bool use_split_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN; + if (n_backbone <= 0) { + throw std::runtime_error("Gemma 4 MTP assistant requires backbone_embedding_length metadata"); + } + + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + if (model.output == NULL) { + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + model.mtp_pre_proj = create_tensor(ctx_output, tn(LLM_TENSOR_MTP_PRE_PROJ, "weight"), {2*n_backbone, n_embd}, 0); + model.mtp_post_proj = create_tensor(ctx_output, tn(LLM_TENSOR_MTP_POST_PROJ, "weight"), {n_embd, n_backbone}, 0); + + model.mtp_token_ordering = create_tensor(ctx_output, tn(LLM_TENSOR_MTP_TOKEN_ORDERING, "weight"), {n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.mtp_centroids = create_tensor(ctx_output, tn(LLM_TENSOR_MTP_CENTROIDS, "weight"), {n_embd, hparams.mtp_num_centroids}, llama_model_loader::TENSOR_NOT_REQUIRED); + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = use_split_ctx ? ctx_for_layer_split(i) : ctx_layer; + auto & layer = model.layers[i]; + const int64_t n_head = hparams.n_head(i); + const int64_t n_embd_head = hparams.n_embd_head_k(i); + const int64_t n_ff_cur = hparams.n_ff(i); + + layer.attn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head*n_head}, 0); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head*n_head, n_embd}, 0); + + layer.attn_q_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head}, 0); + layer.attn_post_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + layer.out_scale = create_tensor(ctx_split, tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), {1u}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff_cur}, 0); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_cur}, 0); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0); + layer.ffn_post_norm = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } + + return use_mmap_buffer; +} + bool create_tensors_helper::create_starcoder2_tensors(const LLM_TN & tn) { LOADING_PRELUDE @@ -4071,6 +4122,8 @@ bool create_tensors_helper::create_tensors() { use_mmap_buffer = create_gemma_tensors(tn, 3); break; case LLM_ARCH_GEMMA4: use_mmap_buffer = create_gemma4_tensors(tn); break; + case LLM_ARCH_GEMMA4_MTP: + use_mmap_buffer = create_gemma4_mtp_tensors(tn); break; case LLM_ARCH_STARCODER2: use_mmap_buffer = create_starcoder2_tensors(tn); break; case LLM_ARCH_MAMBA: @@ -4140,14 +4193,15 @@ bool create_tensors_helper::create_tensors() { use_mmap_buffer &= !has_buft_overrides; } - if (model.arch == LLM_ARCH_GEMMA4 && (model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN)) { - bool supported = true; - if (model.tok_embd_per_layer) { - supported = false; - } - if (!supported) { + { + const bool unsupported = + (model.arch == LLM_ARCH_GEMMA4_MTP) || + (model.arch == LLM_ARCH_GEMMA4 && model.tok_embd_per_layer); + if (unsupported && (model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN)) { LLAMA_LOG_WARN("\n=========================================================\n"); - LLAMA_LOG_WARN("Split mode 'graph' is not supported for this Gemma4 variant\n"); + LLAMA_LOG_WARN("Split mode 'graph' is not supported for %s\n", + model.arch == LLM_ARCH_GEMMA4_MTP ? "Gemma 4 MTP assistant" + : "this Gemma4 variant"); LLAMA_LOG_WARN(" => changing split mode to 'layer'\n"); LLAMA_LOG_WARN("===========================================================\n\n"); model.split_mode = LLAMA_SPLIT_MODE_LAYER; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 8cc7f6ae..fef0069d 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -803,6 +803,28 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_PER_LAYER_POST_NORM, "blk.%d.post_norm" }, }, }, + { + LLM_ARCH_GEMMA4_MTP, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + { LLM_TENSOR_LAYER_OUT_SCALE, "blk.%d.layer_output_scale" }, + { LLM_TENSOR_MTP_PRE_PROJ, "mtp_pre_proj" }, + { LLM_TENSOR_MTP_POST_PROJ, "mtp_post_proj" }, + { LLM_TENSOR_MTP_TOKEN_ORDERING, "mtp_token_ordering" }, + { LLM_TENSOR_MTP_CENTROIDS, "mtp_centroids" }, + }, + }, { LLM_ARCH_STARCODER2, { @@ -1881,6 +1903,27 @@ bool llama_model_has_recurrent(const llama_model * model) { return llm_arch_is_hybrid(model->arch) || llm_arch_is_recurrent(model->arch); } +bool llama_model_is_gemma4_mtp_assistant(const llama_model * model) { + return model && model->arch == LLM_ARCH_GEMMA4_MTP; +} + +bool llama_is_gemma4_mtp_file(const char * path) { + if (!path || !*path) return false; + struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ nullptr }; + struct gguf_context * ctx = gguf_init_from_file(path, params); + if (!ctx) return false; + bool result = false; + const int key_id = gguf_find_key(ctx, "general.architecture"); + if (key_id >= 0) { + const char * arch = gguf_get_val_str(ctx, key_id); + if (arch && strcmp(arch, "gemma4_mtp") == 0) { + result = true; + } + } + gguf_free(ctx); + return result; +} + bool llama_model_is_split_mode_graph(const struct llama_model * model) { return model && (model->split_mode == LLAMA_SPLIT_MODE_GRAPH || model->split_mode == LLAMA_SPLIT_MODE_ATTN); } diff --git a/src/llama-model.h b/src/llama-model.h index c3dce07f..decdbb2b 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -405,6 +405,11 @@ struct llama_model { struct ggml_tensor * per_layer_model_proj = nullptr; struct ggml_tensor * per_layer_proj_norm = nullptr; + struct ggml_tensor * mtp_pre_proj = nullptr; + struct ggml_tensor * mtp_post_proj = nullptr; + struct ggml_tensor * mtp_token_ordering = nullptr; + struct ggml_tensor * mtp_centroids = nullptr; + struct ggml_tensor * output_norm; struct ggml_tensor * output_norm_b; struct ggml_tensor * output; diff --git a/src/llama.cpp b/src/llama.cpp index 51619cea..8edacfb0 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -25,6 +25,9 @@ #include "ggml-alloc.h" #include "ggml-backend.h" +uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx); +void llama_set_mtp_target_context(struct llama_context * ctx, struct llama_context * target_ctx); + // TODO: fix these includes #include "iqk/iqk_quantize.h" #include "iqk/iqk_cpu_ops.h" @@ -562,6 +565,7 @@ void llama_context::reset_scheduler() { bool llama_context::can_reuse_graph(const llama_batch & u_batch) { if (!cparams.graph_reuse) return false; if (kv_self.save_per_step_ssm) return false; + if (model.arch == LLM_ARCH_GEMMA4_MTP && mtp_target_ctx != nullptr) return false; auto the_prev = cparams.mtp_op_type == MTP_OP_NONE ? prev.get() : prev_mtp.get(); if (!the_prev || !the_prev->graph) return false; //if (u_batch.n_tokens > 1) return false; @@ -810,6 +814,9 @@ static bool llama_kv_cache_init( const bool is_mtp_tail = qwen_mtp && i >= n_mtp_first; if (split_cache && !is_mtp_tail) { buft_layer_count[model.buft_layer[i].buft_matrix]++; + if (model.buft_layer[i].buft != model.buft_layer[i].buft_matrix) { + buft_layer_count[model.buft_layer[i].buft]++; + } } else { buft_layer_count[model.buft_layer[i].buft]++; } @@ -2519,6 +2526,10 @@ static std::pair, double> get_layer_sizes(const llama_model_ if (name == "output_norm.weight") { continue; } + if (name == "mtp_pre_proj.weight" || name == "mtp_post_proj.weight" || + name == "mtp_centroids.weight" || name == "mtp_token_ordering.weight") { + continue; + } auto pos = name.find("blk."); if (pos != 0) { LLAMA_LOG_WARN("Oops: tensor with strange name %s\n", name.c_str()); @@ -2706,7 +2717,19 @@ static bool llm_load_tensors( auto & hparams = model.hparams; if (split_mode == LLAMA_SPLIT_MODE_GRAPH || split_mode == LLAMA_SPLIT_MODE_ATTN) { - if (!is_model_split_supported(model)) { + const bool unsupported_gemma_split = + model.arch == LLM_ARCH_GEMMA4_MTP || + (model.arch == LLM_ARCH_GEMMA4 && hparams.n_embd_per_layer > 0); + + if (unsupported_gemma_split) { + LLAMA_LOG_WARN("\n=========================================================\n"); + LLAMA_LOG_WARN("Split mode 'graph' is not supported for %s\n", + model.arch == LLM_ARCH_GEMMA4_MTP ? "Gemma 4 MTP assistant" + : "this Gemma4 variant"); + LLAMA_LOG_WARN(" => changing split mode to 'layer'\n"); + LLAMA_LOG_WARN("===========================================================\n\n"); + split_mode = LLAMA_SPLIT_MODE_LAYER; + } else if (!is_model_split_supported(model)) { LLAMA_LOG_WARN("\n=======================================================\n"); LLAMA_LOG_WARN("Split mode 'graph' is not supported for this model\n"); LLAMA_LOG_WARN(" => changing split mode to 'layer'\n"); @@ -3028,6 +3051,20 @@ static bool llm_load_tensors( } } } + if (model.arch == LLM_ARCH_GEMMA4_MTP && split_mode == LLAMA_SPLIT_MODE_LAYER && device_count > 0 && n_gpu_layers > 0) { + const int mtp_device = std::clamp(main_gpu, 0, device_count - 1); + + LLAMA_LOG_INFO("%s: Gemma 4 MTP assistant forcing layer placement to GPU %d under layer split\n", + __func__, mtp_device); + + for (int i = i_gpu_start; i < n_layer; ++i) { + model.default_layer_device[i] = mtp_device; + } + if (n_gpu_layers > n_layer) { + model.default_layer_device[n_layer] = mtp_device; + } + } + // assign the repeating layers to the devices according to the splits if (split_mode == LLAMA_SPLIT_MODE_LAYER) { for (int i = i_gpu_start; i < n_layer; ++i) { @@ -3609,7 +3646,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { #endif // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. if (cparams.causal_attn && !lctx.is_encoding) { - const int64_t n_kv = kv_self.n; + const llama_kv_cache & mask_kv_self = + (lctx.model.arch == LLM_ARCH_GEMMA4_MTP && lctx.mtp_target_ctx != nullptr) + ? lctx.mtp_target_ctx->kv_self + : kv_self; + const int64_t n_kv = mask_kv_self.n; const int64_t n_tokens = batch.n_tokens; @@ -3636,21 +3677,21 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - auto noalibi_f16 = [&lctx, &hparams, n_kv, data_f16, data_swa_f16] (int j, llama_pos pos, llama_seq_id seq_id, int first, int last) { + auto noalibi_f16 = [&mask_kv_self, &hparams, n_kv, data_f16, data_swa_f16] (int j, llama_pos pos, llama_seq_id seq_id, int first, int last) { ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY); ggml_half h_zero = ggml_fp32_to_fp16(0.f); for (int i = first; i < last; ++i) { - ggml_half h = !lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos ? h_inf : h_zero; + ggml_half h = !mask_kv_self.cells[i].has_seq_id(seq_id) || mask_kv_self.cells[i].pos > pos ? h_inf : h_zero; if (data_f16) data_f16[j*n_kv + i] = h; if (data_swa_f16) { if (h != h_inf) { if (hparams.n_attn_chunk) { llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk; - if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) { + if (mask_kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) { h = h_inf; } } else { - if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { + if (pos - mask_kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { h = h_inf; } } @@ -3663,7 +3704,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { if (n_kv >= 1024 && n_tokens >= 32) { int n_thread = std::max(1, int(std::thread::hardware_concurrency()/2)); int npt = (n_kv + n_thread - 1)/n_thread; - auto compute = [&batch, &lctx, &hparams, &cparams, &noalibi_f16, n_tokens, n_kv, npt, data, data_swa, data_f16, data_swa_f16] (int ith) { + auto compute = [&batch, &mask_kv_self, &hparams, &cparams, &noalibi_f16, n_tokens, n_kv, npt, data, data_swa, data_f16, data_swa_f16] (int ith) { int first = ith * npt; int last = std::min(int(n_kv), first + npt); if (last <= first) return; @@ -3678,11 +3719,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { for (int i = first; i < last; ++i) { float f; - if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { + if (!mask_kv_self.cells[i].has_seq_id(seq_id) || mask_kv_self.cells[i].pos > pos) { f = -INFINITY; } else { if (hparams.use_alibi) { - f = -std::abs(lctx.kv_self.cells[i].pos - pos); + f = -std::abs(mask_kv_self.cells[i].pos - pos); } else { f = 0.0f; } @@ -3700,11 +3741,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { if (f > -INFINITY) { if (hparams.n_attn_chunk) { llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk; - if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) { + if (mask_kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) { f = -INFINITY; } } else { - if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { + if (pos - mask_kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { f = -INFINITY; } } @@ -3759,11 +3800,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { for (int i = 0; i < n_kv; ++i) { float f; - if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { + if (!mask_kv_self.cells[i].has_seq_id(seq_id) || mask_kv_self.cells[i].pos > pos) { f = -INFINITY; } else { if (hparams.use_alibi) { - f = -std::abs(lctx.kv_self.cells[i].pos - pos); + f = -std::abs(mask_kv_self.cells[i].pos - pos); } else { f = 0.0f; } @@ -3780,11 +3821,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { if (data_swa || data_swa_f16) { if (hparams.n_attn_chunk) { llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk; - if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) { + if (mask_kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) { f = -INFINITY; } } else { - if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { + if (pos - mask_kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { f = -INFINITY; } } @@ -4125,6 +4166,21 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { // Make sure enough space is available for outputs. // Returns max number of outputs for which space was reserved. +static uint32_t llama_output_embd_width(const llama_context & lctx) { + const auto & hparams = lctx.model.hparams; + if (lctx.cparams.mtp && lctx.model.arch == LLM_ARCH_GEMMA4_MTP && hparams.mtp_backbone_n_embd > 0) { + return hparams.mtp_backbone_n_embd; + } + return hparams.n_embd; +} + +static bool llama_context_has_mtp_outputs(const llama_context & lctx) { + return lctx.cparams.mtp && ( + lctx.model.hparams.nextn_predict_layers > 0 || + lctx.model.arch == LLM_ARCH_GEMMA4 || + lctx.model.arch == LLM_ARCH_GEMMA4_MTP); +} + static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) { const auto & cparams = lctx.cparams; const auto & hparams = lctx.model.hparams; @@ -4133,10 +4189,10 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) { const auto n_batch = cparams.n_batch; const auto n_vocab = hparams.n_vocab; - const auto n_embd = hparams.n_embd; + const auto n_embd = llama_output_embd_width(lctx); // TODO: use a per-batch flag for logits presence instead - const bool has_mtp = lctx.model.hparams.nextn_predict_layers > 0 && lctx.cparams.mtp; + const bool has_mtp = llama_context_has_mtp_outputs(lctx); const bool has_logits = !cparams.embeddings || has_mtp; const bool has_embd = lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE)) || has_mtp; @@ -4305,7 +4361,8 @@ static int llama_decode_internal( // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE; - const bool has_mtp = cparams.mtp && hparams.nextn_predict_layers > 0; + const bool has_mtp = llama_context_has_mtp_outputs(lctx); + const uint32_t n_embd_output = llama_output_embd_width(lctx); // count outputs if (batch_all.logits && !embd_pooled) { @@ -4521,7 +4578,8 @@ static int llama_decode_internal( printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1)); #endif //if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) { - if (u_batch.embd == nullptr && lctx.cparams.graph_reuse) { + if (u_batch.embd == nullptr && lctx.cparams.graph_reuse && + !(lctx.model.arch == LLM_ARCH_GEMMA4_MTP && lctx.mtp_target_ctx != nullptr)) { prev = std::make_unique(llama_context::Prev{ (int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n, (int)u_batch.n_tokens, cparams.mtp_op_type, gf}); @@ -4546,13 +4604,13 @@ static int llama_decode_internal( res = nullptr; } else { - const bool has_mtp = lctx.model.hparams.nextn_predict_layers > 0 && lctx.model.mtp; - const bool use_qwen_mtp_embd = has_mtp && (lctx.model.arch == LLM_ARCH_QWEN35 || - lctx.model.arch == LLM_ARCH_QWEN35MOE); + const bool has_mtp = llama_context_has_mtp_outputs(lctx); + const bool use_raw_mtp_embd = has_mtp && (lctx.model.arch == LLM_ARCH_QWEN35 || + lctx.model.arch == LLM_ARCH_QWEN35MOE || lctx.model.arch == LLM_ARCH_GEMMA4 || lctx.model.arch == LLM_ARCH_GEMMA4_MTP); if (cparams.embeddings || has_mtp) { for (int i = gf->n_nodes - 1; i >= 0; --i) { - if (use_qwen_mtp_embd && strcmp(gf->nodes[i]->name, "result_mtp_embd") == 0) { - // Qwen 3.5 uses raw hidden state before the final shared-head normalization. + if (use_raw_mtp_embd && strcmp(gf->nodes[i]->name, "result_mtp_embd") == 0) { + // MTP recurrent state can be wider/different than the logits head hidden state. embd = gf->nodes[i]; break; } @@ -4565,7 +4623,7 @@ static int llama_decode_internal( } } } - if (cparams.embeddings && lctx.model.hparams.nextn_predict_layers == 0) { + if (cparams.embeddings && lctx.model.hparams.nextn_predict_layers == 0 && !has_mtp) { res = nullptr; // do not extract logits for embedding case } else { if (!embd) { // do not extract embeddings when not needed @@ -4667,13 +4725,13 @@ static int llama_decode_internal( { // extract token embeddings GGML_ASSERT(lctx.embd != nullptr); - float * embd_out = lctx.embd + n_outputs_prev_embd*n_embd; - const int32_t n_outputs_new_embd = has_mtp ? n_tokens : lctx.n_outputs; + float * embd_out = lctx.embd + n_outputs_prev_embd*n_embd_output; + const int32_t n_outputs_new_embd = has_mtp ? embd->ne[1] : lctx.n_outputs; if (n_outputs_new_embd) { GGML_ASSERT( n_outputs_prev_embd + n_outputs_new_embd <= n_outputs_embd); - GGML_ASSERT((n_outputs_prev_embd + n_outputs_new_embd)*n_embd <= (int64_t) lctx.embd_size); - ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new_embd*n_embd*sizeof(float)); + GGML_ASSERT((n_outputs_prev_embd + n_outputs_new_embd)*n_embd_output <= (int64_t) lctx.embd_size); + ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new_embd*n_embd_output*sizeof(float)); } } break; case LLAMA_POOLING_TYPE_MEAN: @@ -4704,7 +4762,7 @@ static int llama_decode_internal( #endif } n_outputs_prev += lctx.n_outputs; - n_outputs_prev_embd += has_mtp ? n_tokens : lctx.n_outputs; + n_outputs_prev_embd += (has_mtp && embd) ? embd->ne[1] : lctx.n_outputs; cur_token += n_tokens; if (reset_previous) { // We need to discard this graph. Otherwise, iwith CUDA graphs enabled, the graph will get resused and this will reset the @@ -6033,7 +6091,8 @@ struct llama_context * llama_init_from_model( } if (model->arch != LLM_ARCH_GLM4_MOE && model->arch != LLM_ARCH_QWEN35 && - model->arch != LLM_ARCH_QWEN35MOE && cparams.mtp != 0) { + model->arch != LLM_ARCH_QWEN35MOE && model->arch != LLM_ARCH_GEMMA4 && + model->arch != LLM_ARCH_GEMMA4_MTP && cparams.mtp != 0) { cparams.mtp = 0; } @@ -6572,6 +6631,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_SEED_OSS: case LLM_ARCH_STEP35: case LLM_ARCH_GEMMA4: + case LLM_ARCH_GEMMA4_MTP: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: @@ -9903,6 +9963,14 @@ void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float ctx->draft_input_hidden_state = hidden_state; } +uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx) { + return llama_output_embd_width(*ctx); +} + +void llama_set_mtp_target_context(struct llama_context * ctx, struct llama_context * target_ctx) { + ctx->mtp_target_ctx = target_ctx; +} + size_t llama_fill_from_utf8(void* utf8, void* cpts, void* scripts) { return unicode_fill_from_utf8((std::string*)utf8, (std::vector*)cpts, (std::vector*)scripts); }