Add MTP Support for Gemma 4 (#1744)

* gemma-mtp: build the arch to load the MTP model

* gemma-mtp: fix mtp kv state

* gemma-mtp: refactor some functions and create gguf

* gemma-mtp: make usable for embeddings models variant

* gemma-mtp: fix qwen mtp load in graph split

* gemma-mtp: refactor tensor creation and adjust output tensor handling

* Gemma 4 MTP: improve tensor handling, and adjust split mode logic
This commit is contained in:
Samuel Oliveira Alves 2026-05-10 01:44:20 -03:00 committed by GitHub
parent ab0f22b819
commit c2b8bca807
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 1193 additions and 154 deletions

View File

@ -1101,8 +1101,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD;
} else if (value == "suffix") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_SUFFIX;
} else if (value == "mtp") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
params.has_mtp = true;
} else {
throw std::invalid_argument("unknown speculative decoding type without draft model");
throw std::invalid_argument("unknown speculative decoding type");
}
return true;
}
@ -2760,7 +2763,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
" per-step save SSM state per draft step in VRAM; no re-decode on rejection\n"
" gpu-fallback copy state to GPU buffer; re-decode on rejection\n"
" cpu serialise state via llama_state_seq; re-decode on rejection" });
options.push_back({ "*", "--spec-type Name [none | ngram - cache | ngram - simple | ngram - map - k | ngram - map - k4v | ngram - mod | suffix]", "type of speculative decoding to use when no draft model is provided (default: %d)\n", (int)params.speculative.type});
options.push_back({ "*", "--spec-type Name [none | mtp | ngram - cache | ngram - simple | ngram - map - k | ngram - map - k4v | ngram - mod | suffix]", "type of speculative decoding to use (default: %d)\n", (int)params.speculative.type});
options.push_back({ "*", "--spec-ngram-size-n N", "ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)\n",params.speculative.ngram_size_n });
options.push_back({ "*", "--spec-ngram-size-m N", "ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: %d)\n", params.speculative.ngram_size_m });
@ -3355,11 +3358,9 @@ std::string fs_get_cache_file(const std::string & filename) {
}
//
// Model utils
//
struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
llama_init_result iparams;
auto mparams = common_model_params_to_llama(params);
llama_model * model = nullptr;

View File

@ -19,6 +19,9 @@
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
void llama_set_mtp_target_context(struct llama_context * ctx, struct llama_context * target_ctx);
uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx);
const std::vector<enum common_speculative_type> common_speculative_types = {
COMMON_SPECULATIVE_TYPE_NONE,
COMMON_SPECULATIVE_TYPE_DRAFT,
@ -154,27 +157,28 @@ struct common_speculative_state_mtp : public common_speculative_state {
llama_context * ctx_tgt;
llama_context * ctx_mtp = nullptr;
common_sampler * smpl;
// For Gemma 4 external MTP assistant: draft positions are held constant
bool constant_draft_positions = false;
common_speculative_state_mtp(
enum common_speculative_type type,
llama_context * ctx_tgt,
const llama_context_params & mtp_cparams)
llama_context * ctx_mtp,
bool constant_draft_positions = false)
: common_speculative_state(type)
, ctx_tgt(ctx_tgt)
, ctx_mtp(ctx_mtp)
, constant_draft_positions(constant_draft_positions)
{
struct common_params_sampling params;
params.samplers_sequence = {
struct common_params_sampling sparams;
sparams.samplers_sequence = {
llama_sampler_type::DIST,
};
smpl = common_sampler_init(llama_get_model(ctx_tgt), params);
smpl = common_sampler_init(llama_get_model(ctx_mtp), sparams);
llama_set_mtp_target_context(ctx_mtp, ctx_tgt);
const llama_model * model = llama_get_model(ctx_tgt);
ctx_mtp = llama_init_from_model(const_cast<llama_model *>(model), mtp_cparams);
if (ctx_mtp) {
LOG_INF("%s: created MTP context (n_ctx=%d)\n", __func__, llama_n_ctx(ctx_mtp));
} else {
LOG_ERR("%s: failed to create MTP context\n", __func__);
}
LOG_INF("%s: MTP context ready (n_ctx=%d, constant_draft_positions=%s)\n", __func__,
llama_n_ctx(ctx_mtp), constant_draft_positions ? "true" : "false");
}
~common_speculative_state_mtp() override {
@ -211,7 +215,8 @@ struct common_speculative_state_mtp : public common_speculative_state {
params.p_min,
id_last,
n_past,
seq_id
seq_id,
constant_draft_positions
);
}
@ -1029,9 +1034,9 @@ common_speculative * common_speculative_init(
// Compute the implementations to use based on the config and their order of preference
std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
{
bool has_draft = !params.mparams_dft.path.empty();
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP);
bool has_draft = !params.mparams_dft.path.empty() && !has_mtp;
bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
@ -1102,15 +1107,20 @@ common_speculative * common_speculative_init(
break;
}
case COMMON_SPECULATIVE_TYPE_MTP: {
auto mtp_state = std::make_unique<common_speculative_state_mtp>(config.type,
/* .ctx_tgt = */ ctx_tgt,
/* .mtp_cparams = */ params.cparams_dft
);
if (!mtp_state->ctx_mtp) {
LOG_ERR("%s: failed to create MTP context\n", __func__);
return nullptr;
llama_context * ctx_mtp = ctx_dft;
if (!ctx_mtp) {
const llama_model * model = llama_get_model(ctx_tgt);
ctx_mtp = llama_init_from_model(const_cast<llama_model *>(model), params.cparams_dft);
if (!ctx_mtp) {
LOG_ERR("%s: failed to create MTP context\n", __func__);
return nullptr;
}
}
impls.push_back(std::move(mtp_state));
ctx_dft = nullptr;
const bool use_constant_draft_positions = llama_model_is_gemma4_mtp_assistant(llama_get_model(ctx_mtp));
impls.push_back(std::make_unique<common_speculative_state_mtp>(
config.type, ctx_tgt, ctx_mtp, use_constant_draft_positions));
break;
}
case COMMON_SPECULATIVE_TYPE_EAGLE3: {
@ -1224,7 +1234,7 @@ static mtp_last_embd & mtp_get_last_embd(const llama_context * ctx) {
static std::unordered_map<const llama_context *, mtp_last_embd> map;
auto & last = map[ctx];
if (last.embd.empty()) {
auto n_embd = llama_model_n_embd(llama_get_model(ctx));
auto n_embd = llama_mtp_state_n_embd(ctx);
last.embd.resize(n_embd);
}
return last;
@ -1377,7 +1387,8 @@ std::vector<llama_token> mtp_speculative_gen_draft(
float p_min,
llama_token id_last,
int32_t n_past,
llama_seq_id seq_id) {
llama_seq_id seq_id,
bool constant_draft_positions) {
llama_tokens drafts;
drafts.reserve(n_draft);
@ -1394,7 +1405,7 @@ std::vector<llama_token> mtp_speculative_gen_draft(
llama_token current_input_id = id_last;
int32_t current_n_past = n_past;
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
const int n_embd = llama_mtp_state_n_embd(ctx);
auto & last = mtp_get_last_embd(ctx);
int i0 = 0;
@ -1415,7 +1426,8 @@ std::vector<llama_token> mtp_speculative_gen_draft(
int n_decode = 0;
for (int i = i0; i < n_draft; ++i) {
mtp_batch.n_tokens = 0;
common_batch_add(mtp_batch, current_input_id, current_n_past, {seq_id}, true);
const int32_t draft_pos = constant_draft_positions ? n_past : current_n_past;
common_batch_add(mtp_batch, current_input_id, draft_pos, {seq_id}, true);
++n_decode;
if (llama_decode(ctx, mtp_batch) != 0) {

View File

@ -60,7 +60,8 @@ std::vector<llama_token> mtp_speculative_gen_draft(
float p_min,
llama_token id_last,
int32_t n_past,
llama_seq_id seq_id);
llama_seq_id seq_id,
bool constant_draft_positions = false);
void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup);

View File

@ -3175,6 +3175,267 @@ class Gemma2Model(Model):
return [(self.map_tensor_name(name), data_torch)]
class Gemma4BaseModel(Model):
    """Shared conversion helpers for the Gemma 4 converters.

    Resolves hyperparameters from the nested ``text_config`` dict when the
    checkpoint config has one, and writes the tokenizer data to the GGUF file.
    """
    model_arch = gguf.MODEL_ARCH.GEMMA4

    def _text_hparams(self) -> dict[str, Any]:
        """Return ``hparams["text_config"]`` when it is a dict, otherwise ``hparams``."""
        nested = self.hparams.get("text_config")
        return nested if isinstance(nested, dict) else self.hparams

    def _arch_name(self) -> str:
        """Return the GGUF architecture name registered for ``self.model_arch``."""
        return gguf.MODEL_ARCH_NAMES[self.model_arch]

    def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
        """Return the value of the first key found in ``text_config``.

        Falls back to the parent class lookup when there is no ``text_config``
        dict or none of ``keys`` is present in it.
        """
        nested = self.hparams.get("text_config")
        if isinstance(nested, dict):
            for candidate in keys:
                if candidate in nested:
                    return nested[candidate]
        return super().find_hparam(keys, optional)

    def set_vocab(self):
        """Write token list, scores, token types and special-vocab data to the GGUF writer."""
        vocab = gguf.LlamaHfVocab(self.dir_model)

        # Tokens whose type is overridden to USER_DEFINED regardless of what
        # the vocabulary reports for them.
        user_defined_texts = {
            "<|channel>",
            "<channel|>",
            "<|tool_call>",
            "<tool_call|>",
            "<|tool_response>",
            "<tool_response|>",
            "<|\"|>",
        }

        tokens = []
        scores = []
        toktypes = []
        for raw_text, score, default_type in vocab.all_tokens():
            tokens.append(raw_text)
            scores.append(score)
            decoded = raw_text.decode()
            if decoded not in user_defined_texts:
                toktypes.append(default_type)
                continue
            toktypes.append(gguf.TokenType.USER_DEFINED)
            logger.info(f"Token {decoded!r} is set to USER_DEFINED")

        assert len(tokens) == vocab.vocab_size

        writer = self.gguf_writer
        writer.add_tokenizer_model("gemma4")
        writer.add_token_list(tokens)
        writer.add_token_scores(scores)
        writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
        special_vocab.add_to_gguf(writer)
        writer.add_add_space_prefix(False)
        writer.add_add_bos_token(True)
@Model.register("Gemma4ForConditionalGeneration")
class Gemma4Model(Gemma4BaseModel):
    """Converter registered for ``Gemma4ForConditionalGeneration`` checkpoints.

    Only tensors whose name contains ``language_model.`` (or ``rope_freqs``)
    are kept by ``modify_tensors``; everything else is dropped.
    """
    model_arch = gguf.MODEL_ARCH.GEMMA4

    def set_gguf_parameters(self):
        """Write the model hyperparameters to the GGUF writer."""
        hp = self._text_hparams()
        writer = self.gguf_writer
        arch = self._arch_name()
        n_layers = hp["num_hidden_layers"]

        writer.add_context_length(hp["max_position_embeddings"])
        writer.add_embedding_length(hp["hidden_size"])
        writer.add_block_count(n_layers)
        writer.add_head_count(hp["num_attention_heads"])
        writer.add_layer_norm_rms_eps(hp["rms_norm_eps"])
        writer.add_file_type(self.ftype)

        # One boolean per layer: True where the layer uses sliding-window attention.
        swa_layers = [kind == "sliding_attention" for kind in hp["layer_types"]]
        writer.add_sliding_window(hp["sliding_window"])
        writer.add_sliding_window_pattern(swa_layers)

        n_shared_kv = hp.get("num_kv_shared_layers", 0)
        writer.add_shared_kv_layers(n_shared_kv)

        # With use_double_wide_mlp the layers from the first KV-shared layer
        # onwards get twice the feed-forward width, so a per-layer list is written.
        ff_length = hp["intermediate_size"]
        if hp.get("use_double_wide_mlp", False):
            first_shared = n_layers - n_shared_kv
            base_ff = ff_length
            ff_length = [base_ff if il < first_shared else base_ff * 2 for il in range(n_layers)]
        writer.add_feed_forward_length(ff_length)

        expert_ff = hp.get("expert_intermediate_size") or hp.get("moe_intermediate_size")
        if expert_ff is not None:
            writer.add_expert_feed_forward_length(expert_ff)

        writer.add_embedding_length_per_layer_input(hp.get("hidden_size_per_layer_input") or 0)

        head_dim_full = int(hp["global_head_dim"])
        head_dim_swa = int(hp["head_dim"])
        writer.add_key_length(head_dim_full)
        writer.add_value_length(head_dim_full)
        writer.add_uint32(f"{arch}.attention.key_length_swa", head_dim_swa)
        writer.add_uint32(f"{arch}.attention.value_length_swa", head_dim_swa)

        # KV head count: per-layer list when both variants are configured,
        # a single value when only num_key_value_heads is present.
        kv_full = hp.get("num_global_key_value_heads")
        kv_swa = hp.get("num_key_value_heads")
        if kv_swa is not None:
            if kv_full is not None:
                writer.add_head_count_kv([kv_swa if is_swa else kv_full for is_swa in swa_layers])
            else:
                writer.add_head_count_kv(kv_swa)

        rope_cfg = hp.get("rope_parameters", {})
        rope_full = rope_cfg.get("full_attention", {})
        rope_swa = rope_cfg.get("sliding_attention", {})

        writer.add_rope_dimension_count(head_dim_full)
        swa_rotary_factor = float(rope_swa.get("partial_rotary_factor", hp.get("partial_rotary_factor", 1.0)))
        writer.add_uint32(f"{arch}.rope.dimension_count_swa", int(head_dim_swa * swa_rotary_factor))
        writer.add_rope_freq_base(float(rope_full.get("rope_theta", 1000000.0)))
        writer.add_float32(f"{arch}.rope.freq_base_swa", float(rope_swa.get("rope_theta", 10000.0)))

    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
        """Yield the ROPE_FREQS tensor built from the full-attention rope parameters.

        The tensor has ``head_dim_full / 2`` entries: 1.0 for the rotated part
        and 1e30 for the remaining entries.
        """
        hp = self._text_hparams()
        full_rope = hp["rope_parameters"]["full_attention"]
        assert full_rope["rope_type"] == "proportional"

        head_dim_full = int(hp["global_head_dim"])
        rotary_factor = full_rope["partial_rotary_factor"]
        n_rotated = int(head_dim_full * rotary_factor / 2)
        n_unrotated = int(head_dim_full / 2) - n_rotated

        freq_factors = [1.0] * n_rotated + [1e30] * n_unrotated
        yield (
            self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS),
            torch.tensor(freq_factors, dtype=torch.float32),
        )

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Map one checkpoint tensor to its GGUF name, or return ``[]`` to drop it."""
        # Scale tensors are stored without a ".weight" suffix; add it so the
        # name mapping below can resolve them.
        if name.endswith(("per_dim_scale", "layer_scalar")):
            name = name + ".weight"

        keep = "language_model." in name or "rope_freqs" in name
        if not keep:
            return []

        name = name.replace("language_model.", "")
        if name == "lm_head.weight":
            logger.debug(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.")
            return []

        if name.endswith("router.scale"):
            scale_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_INP, bid, ".scale")
            return [(scale_name, data_torch)]

        if ".per_expert_scale" in name:
            scale_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN_EXP, bid, ".scale")
            return [(scale_name, data_torch)]

        if ".experts." in name and not name.endswith(".weight"):
            name += ".weight"

        return [(self.map_tensor_name(name), data_torch)]
@Model.register("Gemma4AssistantForCausalLM")
class Gemma4AssistantModel(Gemma4BaseModel):
    """Converter registered for ``Gemma4AssistantForCausalLM`` (Gemma 4 MTP) checkpoints.

    Tensor names are translated through two explicit lookup tables instead of
    the generic tensor-name map; any tensor not listed raises ``ValueError``.
    """
    model_arch = gguf.MODEL_ARCH.GEMMA4_MTP

    # Exact checkpoint name -> GGUF name for tensors that do not belong to a layer.
    # Several spellings of the token-ordering and centroid tensors are accepted.
    _root_tensor_map = {
        "model.embed_tokens.weight": "token_embd.weight",
        "model.norm.weight": "output_norm.weight",
        "pre_projection.weight": "mtp_pre_proj.weight",
        "post_projection.weight": "mtp_post_proj.weight",
        "masked_embedding.centroids.weight": "mtp_centroids.weight",
        "masked_embedding.token_ordering": "mtp_token_ordering.weight",
        "token_ordering": "mtp_token_ordering.weight",
        "token_ordering.weight": "mtp_token_ordering.weight",
        "model.token_ordering": "mtp_token_ordering.weight",
        "model.token_ordering.weight": "mtp_token_ordering.weight",
        "centroids": "mtp_centroids.weight",
        "centroids.weight": "mtp_centroids.weight",
        "model.centroids": "mtp_centroids.weight",
        "model.centroids.weight": "mtp_centroids.weight",
    }

    # Per-layer suffix (the part after "model.layers.<N>.") -> GGUF suffix
    # (the part after "blk.<N>.").
    _layer_tensor_map = {
        "input_layernorm.weight": "attn_norm.weight",
        "self_attn.q_proj.weight": "attn_q.weight",
        "self_attn.q_norm.weight": "attn_q_norm.weight",
        "self_attn.o_proj.weight": "attn_output.weight",
        "post_attention_layernorm.weight": "post_attention_norm.weight",
        "pre_feedforward_layernorm.weight": "ffn_norm.weight",
        "mlp.gate_proj.weight": "ffn_gate.weight",
        "mlp.up_proj.weight": "ffn_up.weight",
        "mlp.down_proj.weight": "ffn_down.weight",
        "post_feedforward_layernorm.weight": "post_ffw_norm.weight",
        "layer_scalar": "layer_output_scale.weight",
        "layer_scalar.weight": "layer_output_scale.weight",
    }

    def set_gguf_parameters(self):
        """Write the assistant model hyperparameters to the GGUF writer."""
        hparams = self._text_hparams()
        arch = self._arch_name()

        # One boolean per layer: True where the layer uses sliding-window attention.
        sliding_pattern = [layer_type == "sliding_attention" for layer_type in hparams["layer_types"]]

        # The "global" (full-attention) values fall back to the sliding-window
        # ones when the config does not provide them.
        head_dim_swa = int(hparams["head_dim"])
        head_dim_full = int(hparams.get("global_head_dim") or head_dim_swa)
        n_kv_swa = int(hparams["num_key_value_heads"])
        n_kv_full = int(hparams.get("num_global_key_value_heads") or n_kv_swa)
        n_kv = [n_kv_swa if is_sliding else n_kv_full for is_sliding in sliding_pattern]

        self.gguf_writer.add_context_length(int(hparams["max_position_embeddings"]))
        self.gguf_writer.add_embedding_length(int(hparams["hidden_size"]))
        self.gguf_writer.add_block_count(int(hparams["num_hidden_layers"]))
        self.gguf_writer.add_feed_forward_length(int(hparams["intermediate_size"]))
        self.gguf_writer.add_head_count(int(hparams["num_attention_heads"]))
        self.gguf_writer.add_head_count_kv(n_kv)
        self.gguf_writer.add_key_length(head_dim_full)
        self.gguf_writer.add_value_length(head_dim_full)
        self.gguf_writer.add_uint32(f"{arch}.attention.key_length_swa", head_dim_swa)
        self.gguf_writer.add_uint32(f"{arch}.attention.value_length_swa", head_dim_swa)
        self.gguf_writer.add_layer_norm_rms_eps(float(hparams["rms_norm_eps"]))
        self.gguf_writer.add_sliding_window(int(hparams["sliding_window"]))
        self.gguf_writer.add_array(f"{arch}.attention.sliding_window_pattern", sliding_pattern)
        self.gguf_writer.add_rope_dimension_count(head_dim_full)
        self.gguf_writer.add_uint32(f"{arch}.rope.dimension_count_swa", head_dim_swa)

        rope_parameters = hparams.get("rope_parameters", {})
        rope_full = rope_parameters.get("full_attention", {})
        rope_swa = rope_parameters.get("sliding_attention", {})
        self.gguf_writer.add_rope_freq_base(float(rope_full.get("rope_theta", 1000000.0)))
        self.gguf_writer.add_float32(f"{arch}.rope.freq_base_swa", float(rope_swa.get("rope_theta", 10000.0)))

        # These keys are read from the top-level config, not from text_config.
        self.gguf_writer.add_uint32(f"{arch}.backbone_embedding_length", int(self.hparams["backbone_hidden_size"]))
        self.gguf_writer.add_bool(f"{arch}.use_ordered_embeddings", bool(self.hparams.get("use_ordered_embeddings", False)))
        # NOTE(review): the default here is 0 while _token_to_centroid() assumes
        # 2048 when "num_centroids" is missing - confirm which default is intended.
        self.gguf_writer.add_uint32(f"{arch}.centroid_count", int(self.hparams.get("num_centroids", 0)))
        self.gguf_writer.add_uint32(f"{arch}.centroid_top_k", int(self.hparams.get("centroid_intermediate_top_k", 0)))
        self.gguf_writer.add_file_type(self.ftype)

    def _token_to_centroid(self, token_ordering: Tensor) -> Tensor:
        """Convert a token-ordering tensor into a per-token centroid index (int32).

        ``token_ordering[i]`` is taken as the token id stored at position ``i``.
        The ordering is inverted to get each token's position, and the position
        is integer-divided by the number of tokens per centroid.

        Raises:
            ValueError: if ``num_centroids`` is not in ``1..n_vocab``, which
                would otherwise end in a division by zero or negative indices.
        """
        n_vocab = int(token_ordering.shape[0])
        n_centroids = int(self.hparams.get("num_centroids", 2048))
        if n_centroids <= 0 or n_centroids > n_vocab:
            raise ValueError(
                f"Invalid num_centroids={n_centroids} for a token ordering of {n_vocab} entries"
            )

        tokens_per_centroid = n_vocab // n_centroids
        if n_vocab % n_centroids != 0:
            # The trailing positions then map to an index >= num_centroids.
            logger.warning(
                f"Token ordering size {n_vocab} is not a multiple of num_centroids={n_centroids}; "
                f"some tokens will be assigned a centroid index >= {n_centroids}"
            )

        # Index tensors must be integer typed; int64 is used for the scatter below.
        ordering = token_ordering.to(dtype=torch.int64)
        inv_ordering = torch.zeros(n_vocab, dtype=torch.int32)
        inv_ordering[ordering] = torch.arange(n_vocab, dtype=torch.int32)
        return (inv_ordering // tokens_per_centroid).to(dtype=torch.int32)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Map one checkpoint tensor to its GGUF name.

        Raises:
            ValueError: if the tensor is neither a known root tensor nor a
                known ``model.layers.<N>.<suffix>`` tensor.
        """
        del bid  # unused

        mapped_name = self._root_tensor_map.get(name)
        if mapped_name is not None:
            if mapped_name == "mtp_token_ordering.weight":
                return [(mapped_name, self._token_to_centroid(data_torch))]
            return [(mapped_name, data_torch)]

        prefix = "model.layers."
        if not name.startswith(prefix):
            raise ValueError(f"Unsupported Gemma 4 assistant tensor: {name}")

        # partition() never fails: a name without a suffix yields suffix == "",
        # which is not in the map and is reported like any other unknown tensor.
        layer_id, _, suffix = name[len(prefix):].partition(".")
        mapped_suffix = self._layer_tensor_map.get(suffix)
        if mapped_suffix is None:
            raise ValueError(f"Unsupported Gemma 4 assistant tensor: {name}")

        return [(f"blk.{layer_id}.{mapped_suffix}", data_torch)]
@Model.register("Starcoder2ForCausalLM")
class StarCoder2Model(Model):
model_arch = gguf.MODEL_ARCH.STARCODER2

View File

@ -16,6 +16,8 @@
#include <regex>
#include <exception>
uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx);
static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1, int32_t offset = 0) {
if (pos_min == -1) {
pos_min = llama_kv_cache_seq_pos_min(ctx, id);
@ -44,6 +46,97 @@ static void log_text(const gpt_params & params_base, const std::string & text) {
}
}
// Returns true when MTP is enabled in the parameters and the speculative draft
// model is reported as a Gemma 4 MTP assistant.
static bool params_use_gemma4_external_mtp(const gpt_params & params_base) {
    if (!params_base.has_mtp) {
        return false;
    }
    return llama_model_is_gemma4_mtp_assistant(params_base.speculative.model_dft);
}
// Returns the MTP context attached to the slot's speculative state, or `ctx`
// when there is none.
static llama_context * get_slot_mtp_ctx(server_slot & slot, llama_context * ctx) {
    llama_context * const spec_ctx = common_speculative_get_mtp_ctx(slot.spec);
    if (spec_ctx) {
        return spec_ctx;
    }
    return ctx;
}
// Returns llama_mtp_state_n_embd(ctx) as an int, or 0 for a null context.
static int get_ctx_mtp_n_embd(llama_context * ctx) {
    if (!ctx) {
        return 0;
    }
    return (int) llama_mtp_state_n_embd(ctx);
}
// Returns the MTP state width of the context selected by get_slot_mtp_ctx().
static int get_slot_mtp_n_embd(server_slot & slot, llama_context * ctx) {
    llama_context * const mtp_target = get_slot_mtp_ctx(slot, ctx);
    return get_ctx_mtp_n_embd(mtp_target);
}
// Replaces the slot's cached hidden state with the `n_embd` floats starting at
// `hidden`. Does nothing for a null pointer or a non-positive width.
static void cache_slot_mtp_hidden(server_slot & slot, const float * hidden, int n_embd) {
    const bool has_row = hidden != nullptr && n_embd > 0;
    if (has_row) {
        slot.mtp_hidden_state.assign(hidden, hidden + n_embd);
    }
}
// Passes the last complete row of the slot's cached hidden state to
// llama_set_draft_input_hidden_state() on the slot's MTP context.
// Does nothing when the slot has no MTP/speculative state, when the cache is
// empty, or when the cache holds less than one full row.
static void sync_slot_mtp_hidden(server_slot & slot, llama_context * ctx) {
    const bool can_sync = slot.has_mtp && slot.spec && !slot.mtp_hidden_state.empty();
    if (!can_sync) {
        return;
    }

    const int n_embd = get_slot_mtp_n_embd(slot, ctx);
    if (n_embd <= 0) {
        return;
    }

    const size_t row_width = (size_t) n_embd;
    const size_t n_floats  = slot.mtp_hidden_state.size();
    if (n_floats < row_width) {
        return;
    }

    // Any trailing partial row is ignored; only whole rows are counted.
    const size_t last_row = n_floats / row_width - 1;
    llama_set_draft_input_hidden_state(
        get_slot_mtp_ctx(slot, ctx),
        slot.mtp_hidden_state.data() + last_row * row_width);
}
// Caches one hidden-state row on the slot, then forwards the cached state to
// the slot's MTP context. The sync step runs even if nothing was cached.
static void cache_and_sync_slot_mtp_hidden(
        server_slot   & slot,
        llama_context * ctx,
        const float   * hidden,
        int             n_embd) {
    cache_slot_mtp_hidden(slot, hidden, n_embd);
    sync_slot_mtp_hidden(slot, ctx);
}
// Treats `rows` as consecutive rows of `n_embd` floats and caches + syncs the
// last complete row. Does nothing when `n_embd` is non-positive or `rows`
// holds less than one full row.
static void cache_and_sync_slot_mtp_hidden_from_rows(server_slot & slot, llama_context * ctx, const std::vector<float> & rows, int n_embd) {
    if (n_embd <= 0) {
        return;
    }

    const size_t row_width = (size_t) n_embd;
    const size_t row_count = rows.size() / row_width;
    if (row_count == 0) {
        // Covers both an empty vector and a vector shorter than one row.
        return;
    }

    const float * last_row = rows.data() + (row_count - 1) * row_width;
    cache_and_sync_slot_mtp_hidden(slot, ctx, last_row, n_embd);
}
// Applies the hidden states collected for the accepted tokens to the slot.
//
// For the Gemma 4 external MTP path only the last row is cached and synced.
// Otherwise the whole buffer is copied into the slot, handed to the MTP
// context as draft input, and mtp_accept_tokens() is called for `ids`
// starting at `mtp_n_past_base`.
//
// Does nothing when the slot has no MTP, the buffer is empty, or `n_embd` is
// not positive.
static void apply_slot_mtp_accept(
        server_slot & slot,
        llama_context * ctx,
        const std::vector<float> & mtp_hidden_state,
        const std::vector<llama_token> & ids,
        int32_t mtp_n_past_base,
        int n_embd) {
    const bool nothing_to_apply = !slot.has_mtp || mtp_hidden_state.empty() || n_embd <= 0;
    if (nothing_to_apply) {
        return;
    }

    if (slot.use_gemma4_external_mtp) {
        cache_and_sync_slot_mtp_hidden_from_rows(slot, ctx, mtp_hidden_state, n_embd);
    } else {
        slot.mtp_hidden_state = mtp_hidden_state;

        llama_context * const mtp_target = get_slot_mtp_ctx(slot, ctx);
        llama_set_draft_input_hidden_state(mtp_target, slot.mtp_hidden_state.data());
        mtp_accept_tokens(mtp_target, ids, mtp_n_past_base, slot.id);
    }
}
// Caches and syncs a single hidden-state row, but only for slots that have
// MTP enabled and a speculative state, and only for a valid row.
static void set_external_mtp_hidden(server_slot & slot, llama_context * ctx, const float * hidden, int n_embd) {
    const bool slot_ready = slot.has_mtp && slot.spec;
    const bool row_valid  = hidden != nullptr && n_embd > 0;
    if (slot_ready && row_valid) {
        cache_and_sync_slot_mtp_hidden(slot, ctx, hidden, n_embd);
    }
}
// Row-buffer variant of set_external_mtp_hidden(): forwards directly to
// cache_and_sync_slot_mtp_hidden_from_rows() without extra slot checks.
static void set_external_mtp_hidden_from_rows(
        server_slot              & slot,
        llama_context            * ctx,
        const std::vector<float> & rows,
        int                        n_embd) {
    cache_and_sync_slot_mtp_hidden_from_rows(slot, ctx, rows, n_embd);
}
void server_speculative_checkpoint::clear() {
valid = false;
per_step_enabled = false;
@ -185,6 +278,7 @@ bool server_context::load_model(const gpt_params& params_) {
gpt_params params_dft;
params_dft.devices = params_base.speculative.devices;
params_dft.model = params_base.speculative.model;
params_dft.main_gpu = params_base.main_gpu;
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
params_dft.rpc_servers = params_base.rpc_servers;
params_dft.cache_type_k = params_base.speculative.cache_type_k.empty() ? params_base.cache_type_k : params_base.speculative.cache_type_k;
@ -279,16 +373,22 @@ void server_context::init() {
slot.sparams = params_base.sparams;
if (params_base.has_mtp) {
if (llama_model_n_nextn_layer(model) > 0) {
const bool has_external_mtp = params_use_gemma4_external_mtp(params_base);
if (llama_model_n_nextn_layer(model) > 0 || has_external_mtp) {
params_base.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
params_base.pooling_type = LLAMA_POOLING_TYPE_NONE;
params_base.speculative.cparams_dft = common_context_params_to_llama(params_base);
if (!has_external_mtp) {
params_base.speculative.cparams_dft = common_context_params_to_llama(params_base);
}
params_base.speculative.cparams_dft.mtp = true;
params_base.speculative.cparams_dft.mtp_op_type = MTP_OP_WARMUP;
params_base.speculative.cparams_dft.embeddings = true;
slot.has_mtp = true;
slot.use_gemma4_external_mtp = has_external_mtp;
slot.params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
slot.params.speculative.n_min = 0;
slot.params.speculative.cparams_dft = params_base.speculative.cparams_dft;
@ -3276,20 +3376,14 @@ void server_context::add_sampled_tokens() {
auto & params_spec = slot.params.speculative;
if (slot.has_mtp) {
llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec);
llama_context * hs_ctx = mtp_ctx ? mtp_ctx : ctx;
if (!slot.mtp_hidden_state.empty()) {
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
const int n_hidden = slot.mtp_hidden_state.size() / n_embd;
llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data() + (n_hidden - 1) * n_embd);
sync_slot_mtp_hidden(slot, ctx);
} else {
LOG_ERROR("MTP hidden state is empty during speculation", {});
const float* emb_neg1 = llama_get_embeddings_ith(ctx, -1);
if (emb_neg1) {
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
slot.mtp_hidden_state.resize(n_embd);
memcpy(slot.mtp_hidden_state.data(), emb_neg1, n_embd * sizeof(float));
llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data());
const int n_embd = get_ctx_mtp_n_embd(ctx);
cache_and_sync_slot_mtp_hidden(slot, ctx, emb_neg1, n_embd);
}
}
}
@ -3857,11 +3951,8 @@ static void restore_speculative_checkpoint(
// Update MTP KV cache and hidden state using embeddings collected before checkpoint restore.
if (slot.has_mtp && !mtp_hidden_state_pre.empty()) {
slot.mtp_hidden_state = mtp_hidden_state_pre;
llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec);
llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx;
llama_set_draft_input_hidden_state(mtp_target, slot.mtp_hidden_state.data());
mtp_accept_tokens(mtp_target, ids, mtp_n_past_base, slot.id);
const int n_embd = get_ctx_mtp_n_embd(ctx);
apply_slot_mtp_accept(slot, ctx, mtp_hidden_state_pre, ids, mtp_n_past_base, n_embd);
}
SLT_DBG(slot, "per-step restore: step=%d (rejected %d drafts)\n",
@ -3895,7 +3986,7 @@ static void restore_speculative_checkpoint(
SLT_ERR(slot, "failed to re-decode accepted tokens after checkpoint restore: %d\n", ret);
}
if (slot.has_mtp) {
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
const int n_embd = get_ctx_mtp_n_embd(ctx);
const int n_accepted = (int)ids.size();
slot.mtp_hidden_state.resize(n_accepted * n_embd);
@ -3906,15 +3997,17 @@ static void restore_speculative_checkpoint(
}
}
llama_context * mtp_ctx_rej = common_speculative_get_mtp_ctx(slot.spec);
llama_context * mtp_target_rej = mtp_ctx_rej ? mtp_ctx_rej : ctx;
llama_set_draft_input_hidden_state(mtp_target_rej, slot.mtp_hidden_state.data());
mtp_accept_tokens(mtp_target_rej, ids, slot.spec_ckpt.n_past, slot.id);
if (slot.use_gemma4_external_mtp) {
cache_and_sync_slot_mtp_hidden_from_rows(slot, ctx, slot.mtp_hidden_state, n_embd);
} else {
llama_set_draft_input_hidden_state(get_slot_mtp_ctx(slot, ctx), slot.mtp_hidden_state.data());
mtp_accept_tokens(get_slot_mtp_ctx(slot, ctx), ids, slot.spec_ckpt.n_past, slot.id);
if (n_accepted > 1) {
memmove(slot.mtp_hidden_state.data(),
slot.mtp_hidden_state.data() + (n_accepted - 1) * n_embd,
n_embd * sizeof(float));
if (n_accepted > 1) {
memmove(slot.mtp_hidden_state.data(),
slot.mtp_hidden_state.data() + (n_accepted - 1) * n_embd,
n_embd * sizeof(float));
}
}
slot.mtp_hidden_state.resize(n_embd);
}
@ -3972,7 +4065,7 @@ void server_context::speculative_decoding_accept() {
if (slot.has_mtp) {
mtp_n_past_base = slot.n_past - (slot.drafted.size() + 1);
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
const int n_embd = get_ctx_mtp_n_embd(ctx);
if (!ids.empty()) {
mtp_hidden_state_pre.resize(ids.size() * n_embd);
for (size_t i = 0; i < ids.size(); i++) {
@ -4018,13 +4111,9 @@ void server_context::speculative_decoding_accept() {
restore_speculative_checkpoint(slot, ctx, model, ids, n_draft, mtp_hidden_state_pre, mtp_n_past_base);
} else {
if (slot.has_mtp && !mtp_hidden_state_pre.empty()) {
llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec);
llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx;
slot.mtp_hidden_state = std::move(mtp_hidden_state_pre);
llama_set_draft_input_hidden_state(mtp_target, slot.mtp_hidden_state.data());
mtp_accept_tokens(mtp_target, ids, mtp_n_past_base, slot.id);
}
const int n_embd = get_ctx_mtp_n_embd(ctx);
apply_slot_mtp_accept(slot, ctx, mtp_hidden_state_pre, ids, mtp_n_past_base, n_embd);
}
llama_kv_cache_seq_rm(ctx, slot.id, slot.cache_tokens.pos_next(slot.n_past), -1);
discard_speculative_checkpoint(slot, ctx);
}
@ -4396,8 +4485,18 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
}
bool mtp_warmup_needed = false;
llama_context * batch_mtp_target = nullptr;
std::vector<float> batch_mtp_hidden_state;
if (params_base.has_mtp) {
for (auto & slot : slots) {
if (slot.spec && slot.has_mtp) {
llama_context * mc = common_speculative_get_mtp_ctx(slot.spec);
if (mc) {
batch_mtp_target = mc;
break;
}
}
}
for (auto& slot : slots) {
if ((slot.state == SLOT_STATE_PROCESSING && slot.n_decoded == 0) ||
(slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT)) {
@ -4409,13 +4508,16 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
}
}
if (mtp_warmup_needed) {
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
llama_context * mtp_target = batch_mtp_target ? batch_mtp_target : ctx;
const int n_embd_src = get_ctx_mtp_n_embd(ctx);
const int n_embd_dst = get_ctx_mtp_n_embd(mtp_target);
const int n_toks = batch_view.n_tokens;
batch_mtp_hidden_state.resize(n_toks * n_embd);
batch_mtp_hidden_state.assign(n_toks * n_embd_dst, 0.0f);
for (int t = 0; t < n_toks; t++) {
const float* emb_t = llama_get_embeddings_ith(ctx, t);
if (emb_t) {
memcpy(batch_mtp_hidden_state.data() + t * n_embd, emb_t, n_embd * sizeof(float));
const int n_copy = std::min(n_embd_src, n_embd_dst);
memcpy(batch_mtp_hidden_state.data() + t * n_embd_dst, emb_t, n_copy * sizeof(float));
}
}
}
@ -4469,9 +4571,12 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
if (params_base.has_mtp && slot.n_decoded == 0) {
const float* emb_i = llama_get_embeddings_ith(ctx, tok_idx);
if (emb_i) {
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
slot.mtp_hidden_state.resize(n_embd);
memcpy(slot.mtp_hidden_state.data(), emb_i, n_embd * sizeof(float));
const int n_embd = get_ctx_mtp_n_embd(ctx);
if (slot.use_gemma4_external_mtp) {
set_external_mtp_hidden(slot, ctx, emb_i, n_embd);
} else {
cache_slot_mtp_hidden(slot, emb_i, n_embd);
}
}
}
@ -4537,16 +4642,17 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
slot.i_batch = -1;
}
if (mtp_warmup_needed && !batch_mtp_hidden_state.empty()) {
llama_context * mtp_ctx = nullptr;
for (auto & slot : slots) {
if (slot.spec && slot.has_mtp) {
llama_context * mc = common_speculative_get_mtp_ctx(slot.spec);
if (mc) { mtp_ctx = mc; break; }
if (params_use_gemma4_external_mtp(params_base)) {
for (auto & slot : slots) {
if (slot.spec && slot.has_mtp && !slot.mtp_hidden_state.empty()) {
sync_slot_mtp_hidden(slot, ctx);
}
}
} else {
llama_context * mtp_target = batch_mtp_target ? batch_mtp_target : ctx;
llama_set_draft_input_hidden_state(mtp_target, batch_mtp_hidden_state.data());
mtp_update_kv_cache(mtp_target, batch_view, true);
}
llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx;
llama_set_draft_input_hidden_state(mtp_target, batch_mtp_hidden_state.data());
mtp_update_kv_cache(mtp_target, batch_view, true);
}
// speculative decoding - main model sample and accept

View File

@ -171,6 +171,7 @@ struct server_slot {
decltype(ctx_sampling->elb_states) elb_prev_states;
bool has_mtp = false;
bool use_gemma4_external_mtp = false;
std::vector<float> mtp_hidden_state;
// saves recurrent state before a speculative batch so it can be restored on rejection

View File

@ -1107,7 +1107,8 @@ static bool ggml_is_view_op(enum ggml_op op) {
#endif
#ifndef GGML_SCHED_MAX_SPLIT_INPUTS
#define GGML_SCHED_MAX_SPLIT_INPUTS GGML_MAX_SRC
// Gemma4 with per-layer embeddings uses up to 32 inputs
#define GGML_SCHED_MAX_SPLIT_INPUTS 32
#endif
#ifndef GGML_SCHED_MAX_COPIES

View File

@ -234,6 +234,8 @@ class MODEL_ARCH(IntEnum):
GEMMA = auto()
GEMMA2 = auto()
GEMMA3 = auto()
GEMMA4 = auto()
GEMMA4_MTP = auto()
STARCODER2 = auto()
MAMBA = auto()
XVERSE = auto()
@ -282,7 +284,10 @@ class MODEL_TENSOR(IntEnum):
FFN_GATE_INP_SHEXP = auto()
FFN_NORM = auto()
FFN_PRE_NORM = auto()
FFN_PRE_NORM_2 = auto()
FFN_POST_NORM = auto()
FFN_POST_NORM_1 = auto()
FFN_POST_NORM_2 = auto()
FFN_GATE = auto()
FFN_DOWN = auto()
FFN_UP = auto()
@ -291,6 +296,7 @@ class MODEL_TENSOR(IntEnum):
FFN_GATE_EXP = auto()
FFN_DOWN_EXP = auto()
FFN_UP_EXP = auto()
FFN_GATE_UP_EXP = auto()
FFN_GATE_SHEXP = auto()
FFN_DOWN_SHEXP = auto()
FFN_UP_SHEXP = auto()
@ -298,6 +304,13 @@ class MODEL_TENSOR(IntEnum):
ATTN_Q_NORM = auto()
ATTN_K_NORM = auto()
LAYER_OUT_NORM = auto()
LAYER_OUT_SCALE = auto()
PER_LAYER_TOKEN_EMBD = auto()
PER_LAYER_MODEL_PROJ = auto()
PER_LAYER_INP_GATE = auto()
PER_LAYER_PROJ = auto()
PER_LAYER_PROJ_NORM = auto()
PER_LAYER_POST_NORM = auto()
SSM_IN = auto()
SSM_CONV1D = auto()
SSM_X = auto()
@ -349,6 +362,10 @@ class MODEL_TENSOR(IntEnum):
NEXTN_HNORM = auto() # nextn tensors (glm4moe)
NEXTN_SHARED_HEAD_HEAD = auto() # nextn tensors (glm4moe)
NEXTN_SHARED_HEAD_NORM = auto() # nextn tensors (glm4moe)
MTP_PRE_PROJ = auto()
MTP_POST_PROJ = auto()
MTP_TOKEN_ORDERING = auto()
MTP_CENTROIDS = auto()
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@ -383,6 +400,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.GEMMA2: "gemma2",
MODEL_ARCH.GEMMA3: "gemma3",
MODEL_ARCH.GEMMA4: "gemma4",
MODEL_ARCH.GEMMA4_MTP: "gemma4_mtp",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",
@ -434,7 +453,10 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp",
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
MODEL_TENSOR.FFN_PRE_NORM: "blk.{bid}.ffn_norm",
MODEL_TENSOR.FFN_PRE_NORM_2: "blk.{bid}.pre_ffw_norm_2",
MODEL_TENSOR.FFN_POST_NORM: "blk.{bid}.post_ffw_norm",
MODEL_TENSOR.FFN_POST_NORM_1: "blk.{bid}.post_ffw_norm_1",
MODEL_TENSOR.FFN_POST_NORM_2: "blk.{bid}.post_ffw_norm_2",
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
@ -446,8 +468,16 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
MODEL_TENSOR.FFN_GATE_UP_EXP: "blk.{bid}.ffn_gate_up_exps",
MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
MODEL_TENSOR.LAYER_OUT_SCALE: "blk.{bid}.layer_output_scale",
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd",
MODEL_TENSOR.PER_LAYER_MODEL_PROJ: "per_layer_model_proj",
MODEL_TENSOR.PER_LAYER_INP_GATE: "blk.{bid}.inp_gate",
MODEL_TENSOR.PER_LAYER_PROJ: "blk.{bid}.proj",
MODEL_TENSOR.PER_LAYER_PROJ_NORM: "per_layer_proj_norm",
MODEL_TENSOR.PER_LAYER_POST_NORM: "blk.{bid}.post_norm",
MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
@ -500,6 +530,10 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.NEXTN_HNORM: "blk.{bid}.nextn.hnorm",
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: "blk.{bid}.nextn.shared_head_head",
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: "blk.{bid}.nextn.shared_head_norm",
MODEL_TENSOR.MTP_PRE_PROJ: "mtp_pre_proj",
MODEL_TENSOR.MTP_POST_PROJ: "mtp_post_proj",
MODEL_TENSOR.MTP_TOKEN_ORDERING: "mtp_token_ordering",
MODEL_TENSOR.MTP_CENTROIDS: "mtp_centroids",
}
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@ -962,6 +996,56 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_PRE_NORM,
MODEL_TENSOR.FFN_POST_NORM,
],
MODEL_ARCH.GEMMA4: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_PRE_NORM_2,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_POST_NORM,
MODEL_TENSOR.FFN_POST_NORM_1,
MODEL_TENSOR.FFN_POST_NORM_2,
MODEL_TENSOR.FFN_GATE_UP_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.LAYER_OUT_SCALE,
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD,
MODEL_TENSOR.PER_LAYER_MODEL_PROJ,
MODEL_TENSOR.PER_LAYER_INP_GATE,
MODEL_TENSOR.PER_LAYER_PROJ,
MODEL_TENSOR.PER_LAYER_PROJ_NORM,
MODEL_TENSOR.PER_LAYER_POST_NORM,
],
MODEL_ARCH.GEMMA4_MTP: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_POST_NORM,
MODEL_TENSOR.LAYER_OUT_SCALE,
MODEL_TENSOR.MTP_PRE_PROJ,
MODEL_TENSOR.MTP_POST_PROJ,
MODEL_TENSOR.MTP_TOKEN_ORDERING,
MODEL_TENSOR.MTP_CENTROIDS,
],
MODEL_ARCH.STARCODER2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,

View File

@ -245,12 +245,24 @@ class TensorNameMap:
"model.layers.{bid}.pre_feedforward_layernorm", # gemma2
),
MODEL_TENSOR.FFN_PRE_NORM_2: (
"model.layers.{bid}.pre_feedforward_layernorm_2", # gemma4
),
# Post feed-forward norm
MODEL_TENSOR.FFN_POST_NORM: (
"model.layers.{bid}.post_feedforward_layernorm", # gemma2
"model.layers.{bid}.post_moe_norm", # grok-2
),
MODEL_TENSOR.FFN_POST_NORM_1: (
"model.layers.{bid}.post_feedforward_layernorm_1", # gemma4
),
MODEL_TENSOR.FFN_POST_NORM_2: (
"model.layers.{bid}.post_feedforward_layernorm_2", # gemma4
),
MODEL_TENSOR.FFN_GATE_INP: (
"layers.{bid}.feed_forward.gate", # mixtral
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
@ -305,6 +317,11 @@ class TensorNameMap:
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe (merged)
),
MODEL_TENSOR.FFN_GATE_UP_EXP: (
"model.layers.{bid}.mlp.experts.gate_up_proj", # gemma4
"model.layers.{bid}.experts.gate_up_proj", # gemma4
),
MODEL_TENSOR.FFN_UP_SHEXP: (
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek2
@ -413,6 +430,34 @@ class TensorNameMap:
"model.layers.{bid}.final_layernorm", # bailingmoe2
),
MODEL_TENSOR.LAYER_OUT_SCALE: (
"model.layers.{bid}.layer_scalar", # gemma4
),
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: (
"model.embed_tokens_per_layer", # gemma4
),
MODEL_TENSOR.PER_LAYER_MODEL_PROJ: (
"model.per_layer_model_projection", # gemma4
),
MODEL_TENSOR.PER_LAYER_PROJ_NORM: (
"model.per_layer_projection_norm", # gemma4
),
MODEL_TENSOR.PER_LAYER_INP_GATE: (
"model.layers.{bid}.per_layer_input_gate", # gemma4
),
MODEL_TENSOR.PER_LAYER_PROJ: (
"model.layers.{bid}.per_layer_projection", # gemma4
),
MODEL_TENSOR.PER_LAYER_POST_NORM: (
"model.layers.{bid}.post_per_layer_input_norm", # gemma4
),
MODEL_TENSOR.SSM_IN: (
"model.layers.{bid}.in_proj",
"backbone.layers.{bid}.mixer.in_proj",

View File

@ -685,6 +685,11 @@ extern "C" {
LLAMA_API bool llama_model_has_recurrent(const struct llama_model * model);
// Returns true if the model is a Gemma 4 MTP assistant (external frozen-KV speculative drafter)
LLAMA_API bool llama_model_is_gemma4_mtp_assistant(const struct llama_model * model);
LLAMA_API bool llama_is_gemma4_mtp_file(const char * path);
LLAMA_API bool llama_model_is_split_mode_graph(const struct llama_model * model);
// Returns 0 on success

View File

@ -2,6 +2,128 @@
#include "../llama-model.h"
#include "../llama-context.h"
// Maps an assistant (MTP) layer onto the target-model layer whose KV cache it should read.
//
// The candidate range is the first `n_layer_kv_from_start` target layers when that field is
// positive (capped at `n_layer`), otherwise all target layers. Candidates are scanned from the
// last one downwards and the first layer whose sliding-window flag equals the assistant layer's
// flag is returned.
//
// mtp_hparams    - hyper-parameters of the assistant model (swa_layers, n_layer are read)
// target_hparams - hyper-parameters of the target model
// mtp_il         - assistant layer index, must be in [0, mtp_hparams.n_layer)
//
// Returns the matching target layer index; aborts via GGML_ASSERT when no layer matches.
static int gemma4_mtp_target_kv_layer(const llama_hparams & mtp_hparams, const llama_hparams & target_hparams, int mtp_il) {
    GGML_ASSERT(mtp_il >= 0 && mtp_il < (int) mtp_hparams.n_layer);

    const bool want_sliding = mtp_hparams.swa_layers[mtp_il] != 0;

    // Upper bound (exclusive) of the target layers that are searched.
    int n_candidates = (int) target_hparams.n_layer;
    if (target_hparams.n_layer_kv_from_start > 0) {
        n_candidates = std::min<int>((int) target_hparams.n_layer, target_hparams.n_layer_kv_from_start);
    }

    // Walk backwards so the deepest compatible layer wins; -1 means "nothing found".
    int target_il = -1;
    for (int il = n_candidates - 1; il >= 0; --il) {
        const bool layer_sliding = target_hparams.swa_layers[il] != 0;
        if (layer_sliding == want_sliding) {
            target_il = il;
            break;
        }
    }

    GGML_ASSERT(target_il >= 0 && "Gemma4 MTP could not find a matching target KV layer");
    return target_il;
}
// Builds F16 K and V tensors that expose one layer of the target model's KV cache to the
// Gemma4 MTP assistant when that cache layer carries a ggml_split_tensor_t in `extra`
// (one sub-tensor per device). The per-device pieces are viewed, concatenated along the
// head dimension and returned through *frozen_k / *frozen_v.
//
// ctx0         - ggml context in which the new graph nodes are created
// lctx         - assistant context; cparams.flash_attn and model.hparams are read
// target_kv    - KV cache of the target model
// assistant_il - assistant layer index (used for head sizes and callback ids)
// target_il    - index into target_kv.k_l / target_kv.v_l
// target_n_kv  - number of cache cells covered by the views
// frozen_k/v   - in/out: left untouched when already set or when nothing can be built
// cb           - graph-build callback used to name the created tensors
//
// The function returns without writing the outputs when flash attention is disabled or
// when the selected cache tensors are missing or have no split descriptor.
static void gemma4_mtp_prepare_frozen_kv_views(
ggml_context * ctx0,
llama_context & lctx,
const llama_kv_cache & target_kv,
int assistant_il,
int target_il,
int32_t target_n_kv,
ggml_tensor ** frozen_k,
ggml_tensor ** frozen_v,
const llm_build_cb & cb) {
// Already prepared by an earlier call: K and V must have been set together.
if (*frozen_k || *frozen_v) {
GGML_ASSERT(*frozen_k && *frozen_v);
return;
}
// Only the flash-attention path is handled here.
if (!lctx.cparams.flash_attn) {
return;
}
GGML_ASSERT(target_il >= 0 && target_il < (int) target_kv.k_l.size() && target_il < (int) target_kv.v_l.size());
ggml_tensor * k_cache = target_kv.k_l[target_il];
ggml_tensor * v_cache = target_kv.v_l[target_il];
// No split descriptor on this cache layer: leave the outputs null.
if (!k_cache || !v_cache || !k_cache->extra || !v_cache->extra) {
return;
}
auto * split_k = (ggml_split_tensor_t *) k_cache->extra;
auto * split_v = (ggml_split_tensor_t *) v_cache->extra;
GGML_ASSERT(split_k && split_v);
GGML_ASSERT(split_k->n_device == split_v->n_device);
// Head sizes come from the assistant's hparams.
// NOTE(review): this assumes they match the target cache layout - confirm.
const llama_hparams & assistant_hparams = lctx.model.hparams;
const int64_t n_embd_head_k = assistant_hparams.n_embd_head_k(assistant_il);
const int64_t n_embd_head_v = assistant_hparams.n_embd_head_v(assistant_il);
std::vector<ggml_tensor *> k_parts;
std::vector<ggml_tensor *> v_parts;
k_parts.reserve(split_k->n_device);
v_parts.reserve(split_v->n_device);
for (int id = 0; id < split_k->n_device; ++id) {
ggml_tensor * split_kl = split_k->splits[id];
ggml_tensor * split_vl = split_v->splits[id];
// A device either holds both K and V for this layer or neither.
GGML_ASSERT((split_kl && split_vl) || (!split_kl && !split_vl));
if (!split_kl) {
continue;
}
GGML_ASSERT(target_kv.size > 0);
GGML_ASSERT(split_kl->ne[1] % target_kv.size == 0);
// Number of KV heads held by this device, derived from the K split;
// presumably ne[1] == heads_on_device * kv_size - TODO confirm.
// The same count is reused for the V view below.
const int64_t split_n_head_kv = split_kl->ne[1] / target_kv.size;
// K view: [head_dim, n_kv, heads_on_device] starting at offset 0 of the split.
ggml_tensor * k_part = ggml_view_3d(ctx0, split_kl,
n_embd_head_k, target_n_kv, split_n_head_kv,
ggml_row_size(split_kl->type, n_embd_head_k) * split_n_head_kv,
ggml_row_size(split_kl->type, n_embd_head_k),
0);
// Convert to F32 before concatenation (presumably required by ggml_concat - TODO confirm).
if (k_part->type != GGML_TYPE_F32) {
k_part = ggml_cast(ctx0, k_part, GGML_TYPE_F32);
}
// Callback index encodes both the assistant layer and the device id.
cb(k_part, "mtp_frozen_k_split", 1000 * (assistant_il + 1) + id);
// V view with the same logical shape as the K view.
ggml_tensor * v_part = ggml_view_3d(ctx0, split_vl,
n_embd_head_v, target_n_kv, split_n_head_kv,
ggml_row_size(split_vl->type, split_n_head_kv * n_embd_head_v),
ggml_row_size(split_vl->type, n_embd_head_v),
0);
if (v_part->type != GGML_TYPE_F32) {
v_part = ggml_cast(ctx0, v_part, GGML_TYPE_F32);
}
cb(v_part, "mtp_frozen_v_split", 1000 * (assistant_il + 1) + id);
k_parts.push_back(k_part);
v_parts.push_back(v_part);
}
// At least one device must have contributed a piece.
GGML_ASSERT(!k_parts.empty() && k_parts.size() == v_parts.size());
ggml_tensor * k_full = k_parts[0];
ggml_tensor * v_full = v_parts[0];
// Join the per-device pieces along dimension 2 (the head dimension of the 3-D views).
for (size_t i = 1; i < k_parts.size(); ++i) {
k_full = ggml_concat(ctx0, k_full, k_parts[i], 2);
v_full = ggml_concat(ctx0, v_full, v_parts[i], 2);
}
// The tensors handed back to the caller are always F16.
if (k_full->type != GGML_TYPE_F16) {
k_full = ggml_cast(ctx0, k_full, GGML_TYPE_F16);
}
if (v_full->type != GGML_TYPE_F16) {
v_full = ggml_cast(ctx0, v_full, GGML_TYPE_F16);
}
cb(k_full, "mtp_frozen_k", assistant_il);
cb(v_full, "mtp_frozen_v", assistant_il);
*frozen_k = k_full;
*frozen_v = v_full;
}
static ggml_cgraph * build_gemma4_graph_parallel(llm_build_context & llm, llama_context & lctx, ggml_context * ctx0,
ggml_tensor * inpL, ggml_tensor * inp_pos, ggml_tensor * inp_out_ids,
ggml_tensor * KQ_mask, ggml_tensor * KQ_mask_swa, int n_tokens, const llm_build_cb & cb) {
@ -363,6 +485,160 @@ static ggml_cgraph * build_gemma4_graph_parallel(llm_build_context & llm, llama_
return gf;
}
// Builds the compute graph of the Gemma4 MTP assistant model.
//
// Graph inputs created here:
//   - lctx.inp_tokens      : I32 token ids, one per batch token
//   - lctx.inp_mtp_states  : F32 [n_backbone, n_tokens] hidden states ("inp_mtp_states")
//   - lctx.inp_KQ_mask / lctx.inp_KQ_mask_swa (full path only)
// Graph outputs: "result_mtp_embd" and "result_output" (logits).
//
// Two paths exist:
//   1. No target context attached, or the batch carries no tokens: the hidden-state input is
//      duplicated as "result_mtp_embd" and logits are computed from a view of it.
//   2. Otherwise: token embeddings are taken from the target model, combined with the hidden
//      states, run through the assistant layers (which attend to the target's KV cache and
//      never store K/V of their own), and projected back with mtp_post_proj.
ggml_cgraph * llm_build_context::build_gemma4_mtp() {
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(n_tokens), false);
const int64_t n_embd = hparams.n_embd;
const int64_t n_vocab = hparams.n_vocab;
const int64_t n_backbone = hparams.mtp_backbone_n_embd;
const int32_t n_layer = hparams.n_layer;
const bool has_target_ctx = lctx.mtp_target_ctx != nullptr;
GGML_ASSERT(n_backbone > 0);
lctx.inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, batch.n_tokens);
cb(lctx.inp_tokens, "inp_tokens", -1);
ggml_set_input(lctx.inp_tokens);
// Hidden states use the backbone width (n_backbone), not the assistant's own n_embd.
ggml_tensor * hidden_state = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_backbone, n_tokens);
ggml_set_name(hidden_state, "inp_mtp_states");
ggml_set_input(hidden_state);
lctx.inp_mtp_states = hidden_state;
// Path 1: no target context or no token ids -> pass the hidden state through.
if (!has_target_ctx || !batch.token) {
// First n_embd values of every hidden-state row.
// NOTE(review): assumes n_embd <= n_backbone - confirm.
ggml_tensor * cur = ggml_view_2d(ctx0, hidden_state, n_embd, n_tokens,
ggml_row_size(hidden_state->type, n_backbone), 0);
cb(cur, "mtp_init_hidden_view", -1);
ggml_tensor * mtp_embd = ggml_dup(ctx0, hidden_state);
cb(mtp_embd, "result_mtp_embd", -1);
ggml_build_forward_expand(gf, mtp_embd);
ggml_tensor * logits = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb);
cb(logits, "result_output", -1);
ggml_build_forward_expand(gf, logits);
GGML_UNUSED(n_vocab);
return gf;
}
// Path 2: full assistant forward pass against the target context.
const llama_model & target_model = lctx.mtp_target_ctx->model;
const llama_hparams & target_hparams = target_model.hparams;
const llama_cparams & target_cparams = lctx.mtp_target_ctx->cparams;
const llama_kv_cache & target_kv = lctx.mtp_target_ctx->kv_self;
GGML_ASSERT(n_tokens <= target_kv.n);
ggml_tensor * inp_pos = build_inp_pos();
// Token embeddings are looked up in the *target* model's embedding table and scaled
// by sqrt(n_backbone).
ggml_tensor * token_embd = ggml_get_rows(ctx0, target_model.tok_embd, lctx.inp_tokens);
cb(token_embd, "inp_embd_target", -1);
token_embd = ggml_scale(ctx0, token_embd, std::sqrt(float(n_backbone)));
cb(token_embd, "inp_embd_scaled", -1);
// Stack [token embedding ; hidden state] along dim 0 and project with mtp_pre_proj.
ggml_tensor * cur = ggml_concat(ctx0, token_embd, hidden_state, 0);
cb(cur, "inp_mtp_combined", -1);
cur = llm_build_lora_mm(lctx, ctx0, model.mtp_pre_proj, cur);
cb(cur, "mtp_pre_proj", -1);
// Attention runs over the target cache, so its current n / head values are used.
const int32_t target_n_kv = target_kv.n;
const int32_t target_kv_head = target_kv.head;
ggml_tensor * KQ_mask = nullptr;
ggml_tensor * KQ_mask_swa = nullptr;
// K/V view slots, one pair per attention type, shared by all layers of that type.
ggml_tensor * frozen_k_swa = nullptr;
ggml_tensor * frozen_v_swa = nullptr;
ggml_tensor * frozen_k_full = nullptr;
ggml_tensor * frozen_v_full = nullptr;
{
// Masks are sized by the target cache (target_n_kv) and the padded token count.
const int64_t n_mask_tokens = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD);
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32, target_n_kv, n_mask_tokens);
cb(lctx.inp_KQ_mask, "KQ_mask", -1);
ggml_set_input(lctx.inp_KQ_mask);
KQ_mask = lctx.inp_KQ_mask;
// The sliding-window mask exists only when the target model uses SWA.
if (target_hparams.n_swa > 0) {
lctx.inp_KQ_mask_swa = ggml_new_tensor_2d(ctx0, flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32, target_n_kv, n_mask_tokens);
cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
ggml_set_input(lctx.inp_KQ_mask_swa);
KQ_mask_swa = lctx.inp_KQ_mask_swa;
}
}
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpL = cur;
// RoPE parameters and window size are taken from the target model / context so that
// the queries line up with the keys already stored in the target cache.
const bool is_sliding = hparams.swa_layers[il] ? true : false;
const float freq_base_l = is_sliding ? target_hparams.rope_freq_base_train_swa : target_cparams.rope_freq_base;
const float freq_scale_l = is_sliding ? target_hparams.rope_freq_scale_train_swa : target_cparams.rope_freq_scale;
const int n_rot_l = is_sliding ? target_hparams.n_rot_swa : target_hparams.n_rot;
const int n_swa = is_sliding ? target_hparams.n_swa : 0;
const int n_embd_head = hparams.n_embd_head_k(il);
const int n_head = hparams.n_head(il);
// NOTE(review): KQ_mask_swa is null when target_hparams.n_swa == 0; a sliding assistant
// layer would then get a null mask - confirm this combination cannot occur.
ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
// Only a query projection is computed; K and V come from the target cache.
ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur_normed", il);
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur_rope", il);
// Pick the target layer whose cache matches this layer's attention type.
const int target_il = gemma4_mtp_target_kv_layer(hparams, target_hparams, il);
ggml_tensor *& frozen_k = is_sliding ? frozen_k_swa : frozen_k_full;
ggml_tensor *& frozen_v = is_sliding ? frozen_v_swa : frozen_v_full;
gemma4_mtp_prepare_frozen_kv_views(ctx0, lctx, target_kv, il, target_il, target_n_kv, &frozen_k, &frozen_v, cb);
// k_cur / v_cur are null: nothing new is stored, attention reads layer target_il of target_kv.
cur = llm_build_kv(ctx0, lctx, target_kv, gf, model.layers[il].wo, model.layers[il].bo,
nullptr, nullptr, Qcur, KQ_mask_l, n_tokens, target_kv_head, target_n_kv, hparams.f_attention_scale, cb, il, nullptr, n_swa, target_il,
&frozen_k, &frozen_v);
cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(cur, "attn_post_norm", il);
// Residual connection around the attention block.
cur = ggml_add(ctx0, cur, inpL);
cb(cur, "attn_out", il);
// Gated GELU feed-forward; ffn_norm and ffn_post_norm are passed in and applied by llm_build_ffn.
// NOTE(review): the FFN residual is presumably added inside llm_build_ffn (the `true` flag) - confirm.
ggml_tensor * ffn = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, cur,
model.layers[il].ffn_up, nullptr, nullptr,
model.layers[il].ffn_gate, nullptr, nullptr,
model.layers[il].ffn_down, nullptr, nullptr,
nullptr,
LLM_FFN_GELU, LLM_FFN_PAR, cb, il, gf, true, false, nullptr, model.layers[il].ffn_post_norm);
cb(ffn, "ffn_out", il);
cur = ffn;
// Optional per-layer output scale (the tensor is not required at load time).
if (model.layers[il].out_scale) {
cur = ggml_mul(ctx0, cur, model.layers[il].out_scale);
cb(cur, "out_scaled", il);
}
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);
}
// Project the final assistant state with mtp_post_proj and expose it as "result_mtp_embd".
ggml_tensor * mtp_embd = llm_build_lora_mm(lctx, ctx0, model.mtp_post_proj, cur);
cb(mtp_embd, "result_mtp_embd", -1);
ggml_build_forward_expand(gf, mtp_embd);
ggml_tensor * logits;
// E2B/E4B: The centroid/token-ordering tensors are kept in the GGUF for future use but
// not required for correct inference — the full-vocab matmul against the tied output
// weight still yields valid per-token logits.
{
// Logits are computed from the pre-projection state `cur`, not from mtp_embd.
logits = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb);
cb(logits, "result_output", -1);
}
ggml_build_forward_expand(gf, logits);
GGML_UNUSED(n_embd);
GGML_UNUSED(n_vocab);
return gf;
}
static ggml_tensor * gemma4_project_per_layer_inputs(ggml_context * ctx0, const llama_model & model, const llm_build_cb & cb,
int n_embd, int n_embd_per_layer, int n_layer, int n_tokens,
ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
@ -614,6 +890,12 @@ ggml_cgraph * llm_build_context::build_gemma4() {
cur = inpL;
if (cparams.mtp) {
ggml_tensor * mtp_embd = ggml_dup(ctx0, cur);
cb(mtp_embd, "result_mtp_embd", -1);
ggml_build_forward_expand(gf, mtp_embd);
}
cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);

View File

@ -78,6 +78,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GLM_DSA, "glm-dsa" },
{ LLM_ARCH_MISTRAL4, "mistral4" },
{ LLM_ARCH_GEMMA4, "gemma4" },
{ LLM_ARCH_GEMMA4_MTP, "gemma4_mtp" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@ -140,6 +141,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_SWIGLU_LIMITS_SHARED, "%s.swiglu_limits_shared" },
{ LLM_KV_SWIGLU_CLAMP_EXP, "%s.swiglu_clamp_exp" },
{ LLM_KV_SWIGLU_CLAMP_SHEXP, "%s.swiglu_clamp_shexp" },
{ LLM_KV_MTP_BACKBONE_EMBEDDING_LENGTH, "%s.backbone_embedding_length" },
{ LLM_KV_MTP_USE_ORDERED_EMBEDDINGS, "%s.use_ordered_embeddings" },
{ LLM_KV_MTP_CENTROID_COUNT, "%s.centroid_count" },
{ LLM_KV_MTP_CENTROID_TOP_K, "%s.centroid_top_k" },
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },

View File

@ -77,6 +77,7 @@ enum llm_arch {
LLM_ARCH_GLM_DSA,
LLM_ARCH_MISTRAL4,
LLM_ARCH_GEMMA4,
LLM_ARCH_GEMMA4_MTP,
LLM_ARCH_UNKNOWN,
};
@ -133,6 +134,10 @@ enum llm_kv {
LLM_KV_SWIGLU_CLAMP_EXP,
LLM_KV_SWIGLU_CLAMP_SHEXP,
LLM_KV_EMBEDDING_LENGTH_PER_LAYER,
LLM_KV_MTP_BACKBONE_EMBEDDING_LENGTH,
LLM_KV_MTP_USE_ORDERED_EMBEDDINGS,
LLM_KV_MTP_CENTROID_COUNT,
LLM_KV_MTP_CENTROID_TOP_K,
LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV,
@ -358,6 +363,10 @@ enum llm_tensor {
LLM_TENSOR_FFN_PRE_NORM_2, // 105
LLM_TENSOR_FFN_POST_NORM_1,
LLM_TENSOR_FFN_POST_NORM_2,
LLM_TENSOR_MTP_PRE_PROJ,
LLM_TENSOR_MTP_POST_PROJ,
LLM_TENSOR_MTP_TOKEN_ORDERING,
LLM_TENSOR_MTP_CENTROIDS,
LLM_TENSOR_UNKNOWN,
};

View File

@ -1537,37 +1537,46 @@ static ggml_tensor * llm_build_kqv(
float kq_scale,
const llm_build_cb & cb,
int il,
ggml_tensor * sinks = nullptr, int n_swa = 0) {
ggml_tensor * sinks = nullptr, int n_swa = 0, int kv_il = -1,
ggml_tensor ** k_cache_view = nullptr, ggml_tensor ** v_cache_view = nullptr) {
const llama_model & model = lctx.model;
const llama_hparams & hparams = lctx.model.hparams;
const llama_cparams & cparams = lctx.cparams;
const int64_t n_ctx = cparams.n_ctx;
const int64_t n_ctx = kv.size;
const int64_t n_head = hparams.n_head(il);
const int64_t n_head_kv = hparams.n_head_kv(il);
const int64_t n_embd_head_k = hparams.n_embd_head_k(il);
//const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_head_v = hparams.n_embd_head_v(il);
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
const int kv_layer = kv_il >= 0 ? kv_il : il;
struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
cb(q, "q", il);
auto k_cache = lctx.model.hparams.has_kv(il) ? kv.k_l[il]
auto k_cache = kv_il >= 0 ? kv.k_l[kv_layer]
: lctx.model.hparams.has_kv(il) ? kv.k_l[il]
: lctx.model.hparams.swa_layers[il] ? kv.k_l[hparams.n_layer_kv_from_start-2] : kv.k_l[hparams.n_layer_kv_from_start-1];
auto v_cache = lctx.model.hparams.has_kv(il) ? kv.v_l[il]
auto v_cache = kv_il >= 0 ? kv.v_l[kv_layer]
: lctx.model.hparams.has_kv(il) ? kv.v_l[il]
: lctx.model.hparams.swa_layers[il] ? kv.v_l[hparams.n_layer_kv_from_start-2] : kv.v_l[hparams.n_layer_kv_from_start-1];
GGML_ASSERT(k_cache != nullptr && "k_cache is null in llm_build_kqv");
GGML_ASSERT(v_cache != nullptr && "v_cache is null in llm_build_kqv");
struct ggml_tensor * k =
ggml_view_3d(ctx, k_cache,
n_embd_head_k, n_kv, n_head_kv,
ggml_row_size(k_cache->type, n_embd_head_k)*n_head_kv, //n_embd_k_gqa),
ggml_row_size(k_cache->type, n_embd_head_k),
0);
cb(k, "k", il);
struct ggml_tensor * k = k_cache_view ? *k_cache_view : nullptr;
if (!k) {
k = ggml_view_3d(ctx, k_cache,
n_embd_head_k, n_kv, n_head_kv,
ggml_row_size(k_cache->type, n_embd_head_k)*n_head_kv, //n_embd_k_gqa),
ggml_row_size(k_cache->type, n_embd_head_k),
0);
if (k_cache_view) {
*k_cache_view = k;
}
cb(k, "k", il);
}
#ifdef GGML_USE_VULKAN
constexpr bool use_f32_precision = true;
@ -1594,13 +1603,18 @@ static ggml_tensor * llm_build_kqv(
GGML_UNUSED(n_ctx);
// split cached v into n_head heads (not transposed)
struct ggml_tensor * v =
ggml_view_3d(ctx, v_cache,
n_embd_head_v, n_kv, n_head_kv,
ggml_row_size(v_cache->type, n_embd_v_gqa),
ggml_row_size(v_cache->type, n_embd_head_v),
0);
cb(v, "v", il);
struct ggml_tensor * v = v_cache_view ? *v_cache_view : nullptr;
if (!v) {
v = ggml_view_3d(ctx, v_cache,
n_embd_head_v, n_kv, n_head_kv,
ggml_row_size(v_cache->type, n_embd_v_gqa),
ggml_row_size(v_cache->type, n_embd_head_v),
0);
if (v_cache_view) {
*v_cache_view = v;
}
cb(v, "v", il);
}
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
@ -1626,22 +1640,27 @@ static ggml_tensor * llm_build_kqv(
} else {
// split cached v into n_head heads
struct ggml_tensor * v;
if (kv.v_trans) {
v = ggml_view_3d(ctx, v_cache,
n_kv, n_embd_head_v, n_head_kv,
ggml_element_size(v_cache)*n_ctx,
ggml_element_size(v_cache)*n_ctx*n_embd_head_v,
0);
} else {
v = ggml_view_3d(ctx, v_cache,
n_embd_head_v, n_kv, n_head_kv,
ggml_row_size(v_cache->type, n_embd_v_gqa),
ggml_row_size(v_cache->type, n_embd_head_v),
0);
v = ggml_cont(ctx, ggml_transpose(ctx, v));
struct ggml_tensor * v = v_cache_view ? *v_cache_view : nullptr;
if (!v) {
if (kv.v_trans) {
v = ggml_view_3d(ctx, v_cache,
n_kv, n_embd_head_v, n_head_kv,
ggml_element_size(v_cache)*n_ctx,
ggml_element_size(v_cache)*n_ctx*n_embd_head_v,
0);
} else {
v = ggml_view_3d(ctx, v_cache,
n_embd_head_v, n_kv, n_head_kv,
ggml_row_size(v_cache->type, n_embd_v_gqa),
ggml_row_size(v_cache->type, n_embd_head_v),
0);
v = ggml_cont(ctx, ggml_transpose(ctx, v));
}
if (v_cache_view) {
*v_cache_view = v;
}
cb(v, "v", il);
}
cb(v, "v", il);
auto kq_size = k->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024);
if (cparams.attn_max_batch == 0 || cparams.attn_max_batch >= kq_size || k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2] || sinks) {
@ -1775,7 +1794,8 @@ ggml_tensor * llm_build_context::llm_build_kv(
int32_t kv_head,
int32_t n_kv,
float kq_scale,
const llm_build_cb & cb, int il, ggml_tensor * sinks, int n_swa) {
const llm_build_cb & cb, int il, ggml_tensor * sinks, int n_swa, int kv_il,
ggml_tensor ** k_cache_view, ggml_tensor ** v_cache_view) {
const llama_hparams & hparams = lctx.model.hparams;
const llama_cparams & cparams = lctx.cparams;
@ -1805,7 +1825,8 @@ ggml_tensor * llm_build_context::llm_build_kv(
llm_build_kv_store(lctx, ctx, hparams, cparams, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il);
}
auto cur = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b, q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks, n_swa);
auto cur = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b, q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks, n_swa, kv_il,
k_cache_view, v_cache_view);
cb(cur, "kqv_out", il);
return cur;
@ -2332,6 +2353,10 @@ ggml_cgraph * llm_build_context::llama_build_graph(
{
result = llm.build_gemma4();
} break;
case LLM_ARCH_GEMMA4_MTP:
{
result = llm.build_gemma4_mtp();
} break;
case LLM_ARCH_STARCODER2:
{
result = llm.build_starcoder2();

View File

@ -240,6 +240,8 @@ struct llm_build_context {
ggml_cgraph * build_gemma4();
ggml_cgraph * build_gemma4_mtp();
ggml_cgraph * build_starcoder2();
ggml_cgraph * build_mamba();
@ -339,7 +341,8 @@ struct llm_build_context {
int32_t kv_head,
int32_t n_kv,
float kq_scale,
const llm_build_cb & cb, int il, ggml_tensor * sinks = nullptr, int n_swa = 0);
const llm_build_cb & cb, int il, ggml_tensor * sinks = nullptr, int n_swa = 0, int kv_il = -1,
ggml_tensor ** k_cache_view = nullptr, ggml_tensor ** v_cache_view = nullptr);
static ggml_tensor * llm_build_ffn(ggml_context * ctx, llama_context & lctx, ggml_tensor * ffn_norm,
ggml_tensor * cur,

View File

@ -199,6 +199,7 @@ struct llama_context {
struct llama_cparams cparams;
struct llama_sampling sampling;
struct llama_kv_cache kv_self;
struct llama_context * mtp_target_ctx = nullptr;
struct llama_control_vector cvec;
std::vector<float> scale_data;

View File

@ -754,6 +754,26 @@ void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_GEMMA4_MTP:
{
ml.get_key(LLM_KV_MTP_BACKBONE_EMBEDDING_LENGTH, hparams.mtp_backbone_n_embd);
ml.get_key(LLM_KV_MTP_USE_ORDERED_EMBEDDINGS, hparams.mtp_use_ordered_embeddings, false);
ml.get_key(LLM_KV_MTP_CENTROID_COUNT, hparams.mtp_num_centroids, false);
ml.get_key(LLM_KV_MTP_CENTROID_TOP_K, hparams.mtp_centroid_top_k, false);
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer);
ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
hparams.n_layer_kv_from_start = hparams.n_layer;
hparams.f_attention_scale = 1.0f;
switch (hparams.mtp_backbone_n_embd) {
case 5376: model.type = e_model::MODEL_32B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_STARCODER2:
{

View File

@ -134,6 +134,12 @@ struct llama_hparams {
// gemma4 per-layer embedding
uint32_t n_embd_per_layer = 0;
// gemma4 separate assistant MTP
uint32_t mtp_backbone_n_embd = 0;
bool mtp_use_ordered_embeddings = false;
uint32_t mtp_num_centroids = 0;
uint32_t mtp_centroid_top_k = 0;
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
llama_token dec_start_token_id = -1;
@ -152,6 +158,7 @@ struct llama_hparams {
if (this->n_vocab != other.n_vocab) return true;
if (this->n_ctx_train != other.n_ctx_train) return true;
if (this->n_embd != other.n_embd) return true;
if (this->mtp_backbone_n_embd != other.mtp_backbone_n_embd) return true;
if (this->n_layer != other.n_layer) return true;
if (this->n_rot != other.n_rot) return true;
if (this->n_swa != other.n_swa) return true;

View File

@ -95,6 +95,8 @@ struct create_tensors_helper : public create_tensors_helper_interface {
bool create_gemma4_tensors(const LLM_TN & tn);
bool create_gemma4_mtp_tensors(const LLM_TN & tn);
bool create_starcoder2_tensors(const LLM_TN & tn);
bool create_mamba_tensors(const LLM_TN & tn);
@ -2016,6 +2018,7 @@ bool create_tensors_helper::create_gemma4_tensors(const LLM_TN & tn) {
const uint32_t n_embd_per_layer = hparams.n_embd_per_layer;
const int64_t n_ff_exp = hparams.n_ff_exp;
const bool use_split_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN;
if (n_embd_head_k != n_embd_head_v) {
throw std::runtime_error("Gemma 4 requires n_embd_head_k == n_embd_head_v");
@ -2043,7 +2046,8 @@ bool create_tensors_helper::create_gemma4_tensors(const LLM_TN & tn) {
int rope_freqs_flag = 0;
for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_split = ctx_for_layer_split(i);
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = use_split_ctx ? ctx_for_layer_split(i) : ctx_layer;
auto & layer = model.layers[i];
const int64_t n_head = hparams.n_head(i);
const int64_t n_embd_head = hparams.n_embd_head_k(i);
@ -2110,6 +2114,53 @@ bool create_tensors_helper::create_gemma4_tensors(const LLM_TN & tn) {
return use_mmap_buffer;
}
// Creates the weight tensors of a Gemma 4 MTP assistant model.
//
// tn - tensor-name helper used to build the GGUF tensor names
//
// Returns use_mmap_buffer (presumably declared by LOADING_PRELUDE - TODO confirm).
// Throws std::runtime_error when the backbone embedding length is missing (zero).
bool create_tensors_helper::create_gemma4_mtp_tensors(const LLM_TN & tn) {
// NOTE(review): LOADING_PRELUDE is expected to bring n_embd, n_vocab, n_layer, hparams,
// model and the ctx_* handles into scope - verify against the macro definition.
LOADING_PRELUDE
const int64_t n_backbone = hparams.mtp_backbone_n_embd;
// Per-layer split contexts are only used for the 'graph' and 'attn' split modes.
const bool use_split_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN;
if (n_backbone <= 0) {
throw std::runtime_error("Gemma 4 MTP assistant requires backbone_embedding_length metadata");
}
model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
// The output head is optional; when absent the token embedding tensor is reused for it.
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
if (model.output == NULL) {
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
}
// Pre-projection takes 2*n_backbone inputs; post-projection maps n_embd back to n_backbone.
model.mtp_pre_proj = create_tensor(ctx_output, tn(LLM_TENSOR_MTP_PRE_PROJ, "weight"), {2*n_backbone, n_embd}, 0);
model.mtp_post_proj = create_tensor(ctx_output, tn(LLM_TENSOR_MTP_POST_PROJ, "weight"), {n_embd, n_backbone}, 0);
// Token-ordering and centroid tensors are optional.
model.mtp_token_ordering = create_tensor(ctx_output, tn(LLM_TENSOR_MTP_TOKEN_ORDERING, "weight"), {n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
model.mtp_centroids = create_tensor(ctx_output, tn(LLM_TENSOR_MTP_CENTROIDS, "weight"), {n_embd, hparams.mtp_num_centroids}, llama_model_loader::TENSOR_NOT_REQUIRED);
for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = use_split_ctx ? ctx_for_layer_split(i) : ctx_layer;
auto & layer = model.layers[i];
const int64_t n_head = hparams.n_head(i);
const int64_t n_embd_head = hparams.n_embd_head_k(i);
const int64_t n_ff_cur = hparams.n_ff(i);
// Attention: only Q and output projections are created - no K or V weights.
layer.attn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head*n_head}, 0);
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head*n_head, n_embd}, 0);
layer.attn_q_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head}, 0);
layer.attn_post_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
// Single-element per-layer output scale; optional.
layer.out_scale = create_tensor(ctx_split, tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), {1u}, llama_model_loader::TENSOR_NOT_REQUIRED);
// Feed-forward block with its own pre- and post-norm weights.
layer.ffn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff_cur}, 0);
layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_cur}, 0);
layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0);
layer.ffn_post_norm = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
return use_mmap_buffer;
}
bool create_tensors_helper::create_starcoder2_tensors(const LLM_TN & tn) {
LOADING_PRELUDE
@ -4071,6 +4122,8 @@ bool create_tensors_helper::create_tensors() {
use_mmap_buffer = create_gemma_tensors(tn, 3); break;
case LLM_ARCH_GEMMA4:
use_mmap_buffer = create_gemma4_tensors(tn); break;
case LLM_ARCH_GEMMA4_MTP:
use_mmap_buffer = create_gemma4_mtp_tensors(tn); break;
case LLM_ARCH_STARCODER2:
use_mmap_buffer = create_starcoder2_tensors(tn); break;
case LLM_ARCH_MAMBA:
@ -4140,14 +4193,15 @@ bool create_tensors_helper::create_tensors() {
use_mmap_buffer &= !has_buft_overrides;
}
if (model.arch == LLM_ARCH_GEMMA4 && (model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN)) {
bool supported = true;
if (model.tok_embd_per_layer) {
supported = false;
}
if (!supported) {
{
const bool unsupported =
(model.arch == LLM_ARCH_GEMMA4_MTP) ||
(model.arch == LLM_ARCH_GEMMA4 && model.tok_embd_per_layer);
if (unsupported && (model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN)) {
LLAMA_LOG_WARN("\n=========================================================\n");
LLAMA_LOG_WARN("Split mode 'graph' is not supported for this Gemma4 variant\n");
LLAMA_LOG_WARN("Split mode 'graph' is not supported for %s\n",
model.arch == LLM_ARCH_GEMMA4_MTP ? "Gemma 4 MTP assistant"
: "this Gemma4 variant");
LLAMA_LOG_WARN(" => changing split mode to 'layer'\n");
LLAMA_LOG_WARN("===========================================================\n\n");
model.split_mode = LLAMA_SPLIT_MODE_LAYER;

View File

@ -803,6 +803,28 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_PER_LAYER_POST_NORM, "blk.%d.post_norm" },
},
},
{
LLM_ARCH_GEMMA4_MTP,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
{ LLM_TENSOR_LAYER_OUT_SCALE, "blk.%d.layer_output_scale" },
{ LLM_TENSOR_MTP_PRE_PROJ, "mtp_pre_proj" },
{ LLM_TENSOR_MTP_POST_PROJ, "mtp_post_proj" },
{ LLM_TENSOR_MTP_TOKEN_ORDERING, "mtp_token_ordering" },
{ LLM_TENSOR_MTP_CENTROIDS, "mtp_centroids" },
},
},
{
LLM_ARCH_STARCODER2,
{
@ -1881,6 +1903,27 @@ bool llama_model_has_recurrent(const llama_model * model) {
return llm_arch_is_hybrid(model->arch) || llm_arch_is_recurrent(model->arch);
}
// Reports whether `model` is a Gemma 4 MTP assistant model, i.e. whether its
// architecture is LLM_ARCH_GEMMA4_MTP. A null `model` yields false.
bool llama_model_is_gemma4_mtp_assistant(const llama_model * model) {
    if (model == nullptr) {
        return false;
    }
    return model->arch == LLM_ARCH_GEMMA4_MTP;
}
bool llama_is_gemma4_mtp_file(const char * path) {
if (!path || !*path) return false;
struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ nullptr };
struct gguf_context * ctx = gguf_init_from_file(path, params);
if (!ctx) return false;
bool result = false;
const int key_id = gguf_find_key(ctx, "general.architecture");
if (key_id >= 0) {
const char * arch = gguf_get_val_str(ctx, key_id);
if (arch && strcmp(arch, "gemma4_mtp") == 0) {
result = true;
}
}
gguf_free(ctx);
return result;
}
// Reports whether `model` uses one of the graph-style split modes
// (LLAMA_SPLIT_MODE_GRAPH or LLAMA_SPLIT_MODE_ATTN). A null `model` yields false.
bool llama_model_is_split_mode_graph(const struct llama_model * model) {
    if (!model) {
        return false;
    }
    const auto mode = model->split_mode;
    return mode == LLAMA_SPLIT_MODE_GRAPH || mode == LLAMA_SPLIT_MODE_ATTN;
}

View File

@ -405,6 +405,11 @@ struct llama_model {
struct ggml_tensor * per_layer_model_proj = nullptr;
struct ggml_tensor * per_layer_proj_norm = nullptr;
struct ggml_tensor * mtp_pre_proj = nullptr;
struct ggml_tensor * mtp_post_proj = nullptr;
struct ggml_tensor * mtp_token_ordering = nullptr;
struct ggml_tensor * mtp_centroids = nullptr;
struct ggml_tensor * output_norm;
struct ggml_tensor * output_norm_b;
struct ggml_tensor * output;

View File

@ -25,6 +25,9 @@
#include "ggml-alloc.h"
#include "ggml-backend.h"
uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx);
void llama_set_mtp_target_context(struct llama_context * ctx, struct llama_context * target_ctx);
// TODO: fix these includes
#include "iqk/iqk_quantize.h"
#include "iqk/iqk_cpu_ops.h"
@ -562,6 +565,7 @@ void llama_context::reset_scheduler() {
bool llama_context::can_reuse_graph(const llama_batch & u_batch) {
if (!cparams.graph_reuse) return false;
if (kv_self.save_per_step_ssm) return false;
if (model.arch == LLM_ARCH_GEMMA4_MTP && mtp_target_ctx != nullptr) return false;
auto the_prev = cparams.mtp_op_type == MTP_OP_NONE ? prev.get() : prev_mtp.get();
if (!the_prev || !the_prev->graph) return false;
//if (u_batch.n_tokens > 1) return false;
@ -810,6 +814,9 @@ static bool llama_kv_cache_init(
const bool is_mtp_tail = qwen_mtp && i >= n_mtp_first;
if (split_cache && !is_mtp_tail) {
buft_layer_count[model.buft_layer[i].buft_matrix]++;
if (model.buft_layer[i].buft != model.buft_layer[i].buft_matrix) {
buft_layer_count[model.buft_layer[i].buft]++;
}
} else {
buft_layer_count[model.buft_layer[i].buft]++;
}
@ -2519,6 +2526,10 @@ static std::pair<std::vector<double>, double> get_layer_sizes(const llama_model_
if (name == "output_norm.weight") {
continue;
}
if (name == "mtp_pre_proj.weight" || name == "mtp_post_proj.weight" ||
name == "mtp_centroids.weight" || name == "mtp_token_ordering.weight") {
continue;
}
auto pos = name.find("blk.");
if (pos != 0) {
LLAMA_LOG_WARN("Oops: tensor with strange name %s\n", name.c_str());
@ -2706,7 +2717,19 @@ static bool llm_load_tensors(
auto & hparams = model.hparams;
if (split_mode == LLAMA_SPLIT_MODE_GRAPH || split_mode == LLAMA_SPLIT_MODE_ATTN) {
if (!is_model_split_supported(model)) {
const bool unsupported_gemma_split =
model.arch == LLM_ARCH_GEMMA4_MTP ||
(model.arch == LLM_ARCH_GEMMA4 && hparams.n_embd_per_layer > 0);
if (unsupported_gemma_split) {
LLAMA_LOG_WARN("\n=========================================================\n");
LLAMA_LOG_WARN("Split mode 'graph' is not supported for %s\n",
model.arch == LLM_ARCH_GEMMA4_MTP ? "Gemma 4 MTP assistant"
: "this Gemma4 variant");
LLAMA_LOG_WARN(" => changing split mode to 'layer'\n");
LLAMA_LOG_WARN("===========================================================\n\n");
split_mode = LLAMA_SPLIT_MODE_LAYER;
} else if (!is_model_split_supported(model)) {
LLAMA_LOG_WARN("\n=======================================================\n");
LLAMA_LOG_WARN("Split mode 'graph' is not supported for this model\n");
LLAMA_LOG_WARN(" => changing split mode to 'layer'\n");
@ -3028,6 +3051,20 @@ static bool llm_load_tensors(
}
}
}
if (model.arch == LLM_ARCH_GEMMA4_MTP && split_mode == LLAMA_SPLIT_MODE_LAYER && device_count > 0 && n_gpu_layers > 0) {
const int mtp_device = std::clamp(main_gpu, 0, device_count - 1);
LLAMA_LOG_INFO("%s: Gemma 4 MTP assistant forcing layer placement to GPU %d under layer split\n",
__func__, mtp_device);
for (int i = i_gpu_start; i < n_layer; ++i) {
model.default_layer_device[i] = mtp_device;
}
if (n_gpu_layers > n_layer) {
model.default_layer_device[n_layer] = mtp_device;
}
}
// assign the repeating layers to the devices according to the splits
if (split_mode == LLAMA_SPLIT_MODE_LAYER) {
for (int i = i_gpu_start; i < n_layer; ++i) {
@ -3609,7 +3646,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
#endif
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
if (cparams.causal_attn && !lctx.is_encoding) {
const int64_t n_kv = kv_self.n;
const llama_kv_cache & mask_kv_self =
(lctx.model.arch == LLM_ARCH_GEMMA4_MTP && lctx.mtp_target_ctx != nullptr)
? lctx.mtp_target_ctx->kv_self
: kv_self;
const int64_t n_kv = mask_kv_self.n;
const int64_t n_tokens = batch.n_tokens;
@ -3636,21 +3677,21 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
auto noalibi_f16 = [&lctx, &hparams, n_kv, data_f16, data_swa_f16] (int j, llama_pos pos, llama_seq_id seq_id, int first, int last) {
auto noalibi_f16 = [&mask_kv_self, &hparams, n_kv, data_f16, data_swa_f16] (int j, llama_pos pos, llama_seq_id seq_id, int first, int last) {
ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY);
ggml_half h_zero = ggml_fp32_to_fp16(0.f);
for (int i = first; i < last; ++i) {
ggml_half h = !lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos ? h_inf : h_zero;
ggml_half h = !mask_kv_self.cells[i].has_seq_id(seq_id) || mask_kv_self.cells[i].pos > pos ? h_inf : h_zero;
if (data_f16) data_f16[j*n_kv + i] = h;
if (data_swa_f16) {
if (h != h_inf) {
if (hparams.n_attn_chunk) {
llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
if (mask_kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
h = h_inf;
}
} else {
if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
if (pos - mask_kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
h = h_inf;
}
}
@ -3663,7 +3704,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
if (n_kv >= 1024 && n_tokens >= 32) {
int n_thread = std::max(1, int(std::thread::hardware_concurrency()/2));
int npt = (n_kv + n_thread - 1)/n_thread;
auto compute = [&batch, &lctx, &hparams, &cparams, &noalibi_f16, n_tokens, n_kv, npt, data, data_swa, data_f16, data_swa_f16] (int ith) {
auto compute = [&batch, &mask_kv_self, &hparams, &cparams, &noalibi_f16, n_tokens, n_kv, npt, data, data_swa, data_f16, data_swa_f16] (int ith) {
int first = ith * npt;
int last = std::min(int(n_kv), first + npt);
if (last <= first) return;
@ -3678,11 +3719,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
for (int i = first; i < last; ++i) {
float f;
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
if (!mask_kv_self.cells[i].has_seq_id(seq_id) || mask_kv_self.cells[i].pos > pos) {
f = -INFINITY;
} else {
if (hparams.use_alibi) {
f = -std::abs(lctx.kv_self.cells[i].pos - pos);
f = -std::abs(mask_kv_self.cells[i].pos - pos);
} else {
f = 0.0f;
}
@ -3700,11 +3741,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
if (f > -INFINITY) {
if (hparams.n_attn_chunk) {
llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
if (mask_kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
f = -INFINITY;
}
} else {
if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
if (pos - mask_kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
f = -INFINITY;
}
}
@ -3759,11 +3800,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
for (int i = 0; i < n_kv; ++i) {
float f;
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
if (!mask_kv_self.cells[i].has_seq_id(seq_id) || mask_kv_self.cells[i].pos > pos) {
f = -INFINITY;
} else {
if (hparams.use_alibi) {
f = -std::abs(lctx.kv_self.cells[i].pos - pos);
f = -std::abs(mask_kv_self.cells[i].pos - pos);
} else {
f = 0.0f;
}
@ -3780,11 +3821,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
if (data_swa || data_swa_f16) {
if (hparams.n_attn_chunk) {
llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
if (mask_kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
f = -INFINITY;
}
} else {
if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
if (pos - mask_kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
f = -INFINITY;
}
}
@ -4125,6 +4166,21 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
// Width (elements per output row) of the embedding output buffer for `lctx`.
//
// When MTP is enabled on the context, the model architecture is
// LLM_ARCH_GEMMA4_MTP and hparams.mtp_backbone_n_embd is positive, that
// backbone width is returned; in every other case the model's n_embd is used.
static uint32_t llama_output_embd_width(const llama_context & lctx) {
    const auto & hp = lctx.model.hparams;
    const bool gemma4_mtp_active = lctx.cparams.mtp && lctx.model.arch == LLM_ARCH_GEMMA4_MTP;
    if (gemma4_mtp_active && hp.mtp_backbone_n_embd > 0) {
        return hp.mtp_backbone_n_embd;
    }
    return hp.n_embd;
}
// Reports whether `lctx` produces MTP outputs.
//
// Requires cparams.mtp to be set, and additionally either a model with
// nextn_predict_layers > 0 or an architecture of LLM_ARCH_GEMMA4 /
// LLM_ARCH_GEMMA4_MTP.
static bool llama_context_has_mtp_outputs(const llama_context & lctx) {
    if (!lctx.cparams.mtp) {
        return false;
    }
    const auto & mdl = lctx.model;
    if (mdl.hparams.nextn_predict_layers > 0) {
        return true;
    }
    return mdl.arch == LLM_ARCH_GEMMA4 || mdl.arch == LLM_ARCH_GEMMA4_MTP;
}
static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
const auto & cparams = lctx.cparams;
const auto & hparams = lctx.model.hparams;
@ -4133,10 +4189,10 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
const auto n_batch = cparams.n_batch;
const auto n_vocab = hparams.n_vocab;
const auto n_embd = hparams.n_embd;
const auto n_embd = llama_output_embd_width(lctx);
// TODO: use a per-batch flag for logits presence instead
const bool has_mtp = lctx.model.hparams.nextn_predict_layers > 0 && lctx.cparams.mtp;
const bool has_mtp = llama_context_has_mtp_outputs(lctx);
const bool has_logits = !cparams.embeddings || has_mtp;
const bool has_embd = lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE)) || has_mtp;
@ -4305,7 +4361,8 @@ static int llama_decode_internal(
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
const bool has_mtp = cparams.mtp && hparams.nextn_predict_layers > 0;
const bool has_mtp = llama_context_has_mtp_outputs(lctx);
const uint32_t n_embd_output = llama_output_embd_width(lctx);
// count outputs
if (batch_all.logits && !embd_pooled) {
@ -4521,7 +4578,8 @@ static int llama_decode_internal(
printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1));
#endif
//if (u_batch.n_tokens == 1 && u_batch.embd == nullptr && lctx.cparams.graph_reuse) {
if (u_batch.embd == nullptr && lctx.cparams.graph_reuse) {
if (u_batch.embd == nullptr && lctx.cparams.graph_reuse &&
!(lctx.model.arch == LLM_ARCH_GEMMA4_MTP && lctx.mtp_target_ctx != nullptr)) {
prev = std::make_unique<llama_context::Prev>(llama_context::Prev{
(int)u_batch.all_seq_id, (int)lctx.n_outputs, (int)lctx.kv_self.n,
(int)u_batch.n_tokens, cparams.mtp_op_type, gf});
@ -4546,13 +4604,13 @@ static int llama_decode_internal(
res = nullptr;
}
else {
const bool has_mtp = lctx.model.hparams.nextn_predict_layers > 0 && lctx.model.mtp;
const bool use_qwen_mtp_embd = has_mtp && (lctx.model.arch == LLM_ARCH_QWEN35 ||
lctx.model.arch == LLM_ARCH_QWEN35MOE);
const bool has_mtp = llama_context_has_mtp_outputs(lctx);
const bool use_raw_mtp_embd = has_mtp && (lctx.model.arch == LLM_ARCH_QWEN35 ||
lctx.model.arch == LLM_ARCH_QWEN35MOE || lctx.model.arch == LLM_ARCH_GEMMA4 || lctx.model.arch == LLM_ARCH_GEMMA4_MTP);
if (cparams.embeddings || has_mtp) {
for (int i = gf->n_nodes - 1; i >= 0; --i) {
if (use_qwen_mtp_embd && strcmp(gf->nodes[i]->name, "result_mtp_embd") == 0) {
// Qwen 3.5 uses raw hidden state before the final shared-head normalization.
if (use_raw_mtp_embd && strcmp(gf->nodes[i]->name, "result_mtp_embd") == 0) {
// MTP recurrent state can be wider/different than the logits head hidden state.
embd = gf->nodes[i];
break;
}
@ -4565,7 +4623,7 @@ static int llama_decode_internal(
}
}
}
if (cparams.embeddings && lctx.model.hparams.nextn_predict_layers == 0) {
if (cparams.embeddings && lctx.model.hparams.nextn_predict_layers == 0 && !has_mtp) {
res = nullptr; // do not extract logits for embedding case
} else {
if (!embd) { // do not extract embeddings when not needed
@ -4667,13 +4725,13 @@ static int llama_decode_internal(
{
// extract token embeddings
GGML_ASSERT(lctx.embd != nullptr);
float * embd_out = lctx.embd + n_outputs_prev_embd*n_embd;
const int32_t n_outputs_new_embd = has_mtp ? n_tokens : lctx.n_outputs;
float * embd_out = lctx.embd + n_outputs_prev_embd*n_embd_output;
const int32_t n_outputs_new_embd = has_mtp ? embd->ne[1] : lctx.n_outputs;
if (n_outputs_new_embd) {
GGML_ASSERT( n_outputs_prev_embd + n_outputs_new_embd <= n_outputs_embd);
GGML_ASSERT((n_outputs_prev_embd + n_outputs_new_embd)*n_embd <= (int64_t) lctx.embd_size);
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new_embd*n_embd*sizeof(float));
GGML_ASSERT((n_outputs_prev_embd + n_outputs_new_embd)*n_embd_output <= (int64_t) lctx.embd_size);
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new_embd*n_embd_output*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_MEAN:
@ -4704,7 +4762,7 @@ static int llama_decode_internal(
#endif
}
n_outputs_prev += lctx.n_outputs;
n_outputs_prev_embd += has_mtp ? n_tokens : lctx.n_outputs;
n_outputs_prev_embd += (has_mtp && embd) ? embd->ne[1] : lctx.n_outputs;
cur_token += n_tokens;
if (reset_previous) {
// We need to discard this graph. Otherwise, with CUDA graphs enabled, the graph will get reused and this will reset the
@ -6033,7 +6091,8 @@ struct llama_context * llama_init_from_model(
}
if (model->arch != LLM_ARCH_GLM4_MOE && model->arch != LLM_ARCH_QWEN35 &&
model->arch != LLM_ARCH_QWEN35MOE && cparams.mtp != 0) {
model->arch != LLM_ARCH_QWEN35MOE && model->arch != LLM_ARCH_GEMMA4 &&
model->arch != LLM_ARCH_GEMMA4_MTP && cparams.mtp != 0) {
cparams.mtp = 0;
}
@ -6572,6 +6631,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_SEED_OSS:
case LLM_ARCH_STEP35:
case LLM_ARCH_GEMMA4:
case LLM_ARCH_GEMMA4_MTP:
return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_QWEN2VL:
@ -9903,6 +9963,14 @@ void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float
ctx->draft_input_hidden_state = hidden_state;
}
// Public accessor for the embedding output width of `ctx`; forwards to
// llama_output_embd_width(). `ctx` is dereferenced unconditionally, so it must
// not be null.
uint32_t llama_mtp_state_n_embd(const struct llama_context * ctx) {
    const llama_context & lctx = *ctx;
    return llama_output_embd_width(lctx);
}
// Stores `target_ctx` in ctx->mtp_target_ctx. Only the pointer is recorded;
// `ctx` is dereferenced unconditionally, so it must not be null.
// Passing a null `target_ctx` clears the association.
void llama_set_mtp_target_context(struct llama_context * ctx, struct llama_context * target_ctx) {
    llama_context & self = *ctx;
    self.mtp_target_ctx = target_ctx;
}
// Type-erased wrapper around unicode_fill_from_utf8().
//
// The opaque pointers are reinterpreted as follows before forwarding:
//   utf8    -> std::string *
//   cpts    -> std::vector<uint32_t> *
//   scripts -> std::vector<std::string> *
// The callee's return value is passed through unchanged.
// NOTE(review): callers must pass pointers of exactly these types - nothing is checked here.
size_t llama_fill_from_utf8(void* utf8, void* cpts, void* scripts) {
    auto * text        = static_cast<std::string *>(utf8);
    auto * code_points = static_cast<std::vector<uint32_t> *>(cpts);
    auto * script_list = static_cast<std::vector<std::string> *>(scripts);
    return unicode_fill_from_utf8(text, code_points, script_list);
}